1use bytes::Bytes;
11use ipfrs_core::error::{Error, Result};
12use quinn::{
13 ClientConfig, Connection, Endpoint, RecvStream, SendStream, ServerConfig, TransportConfig,
14};
15use std::collections::HashMap;
16use std::net::SocketAddr;
17use std::sync::Arc;
18use std::time::{Duration, Instant};
19use tokio::sync::RwLock;
20
/// Tuning parameters for the QUIC transport.
#[derive(Debug, Clone)]
pub struct QuicConfig {
    /// Local socket address the endpoint binds to.
    pub bind_addr: SocketAddr,
    /// Maximum idle time before a connection is closed.
    pub idle_timeout: Duration,
    /// Maximum number of concurrent streams per connection.
    pub max_streams: u32,
    /// Whether 0-RTT session resumption is enabled.
    pub enable_0rtt: bool,
    /// Maximum pooled connections kept per peer.
    pub pool_size: usize,
    /// How long a pooled connection may sit unused before eviction.
    pub pool_idle_timeout: Duration,
    /// Upper bound on a single message read in one `read_to_end` call.
    pub max_message_size: usize,
    /// Initial flow-control window, in bytes.
    pub initial_window: u32,
    /// Maximum flow-control window, in bytes.
    pub max_window: u32,
}

impl Default for QuicConfig {
    fn default() -> Self {
        QuicConfig {
            // Any interface, ephemeral port.
            bind_addr: "0.0.0.0:0".parse().unwrap(),
            idle_timeout: Duration::from_secs(30),
            max_streams: 256,
            enable_0rtt: true,
            pool_size: 4,
            pool_idle_timeout: Duration::from_secs(60),
            max_message_size: 16 * 1024 * 1024,
            initial_window: 10 * 1024 * 1024,
            max_window: 100 * 1024 * 1024,
        }
    }
}
59
60struct PooledConnection {
62 connection: Connection,
63 #[allow(dead_code)]
65 created_at: Instant,
66 last_used: Instant,
67 active_streams: u32,
68}
69
70impl PooledConnection {
71 fn new(connection: Connection) -> Self {
72 let now = Instant::now();
73 Self {
74 connection,
75 created_at: now,
76 last_used: now,
77 active_streams: 0,
78 }
79 }
80
81 fn is_healthy(&self) -> bool {
82 self.connection.close_reason().is_none()
83 }
84
85 fn is_idle(&self, timeout: Duration) -> bool {
86 self.last_used.elapsed() > timeout && self.active_streams == 0
87 }
88
89 fn touch(&mut self) {
90 self.last_used = Instant::now();
91 }
92}
93
94struct PeerPool {
96 connections: Vec<PooledConnection>,
97 max_size: usize,
98 idle_timeout: Duration,
99}
100
101impl PeerPool {
102 fn new(max_size: usize, idle_timeout: Duration) -> Self {
103 Self {
104 connections: Vec::with_capacity(max_size),
105 max_size,
106 idle_timeout,
107 }
108 }
109
110 fn get(&mut self) -> Option<&mut PooledConnection> {
112 self.connections.retain(|c| c.is_healthy());
114
115 self.connections.retain(|c| !c.is_idle(self.idle_timeout));
117
118 self.connections
120 .iter_mut()
121 .filter(|c| c.is_healthy())
122 .min_by_key(|c| c.active_streams)
123 }
124
125 fn add(&mut self, connection: Connection) -> bool {
127 if self.connections.len() >= self.max_size {
128 if let Some(pos) = self
130 .connections
131 .iter()
132 .position(|c| c.is_idle(Duration::ZERO))
133 {
134 self.connections.remove(pos);
135 } else {
136 return false;
137 }
138 }
139
140 self.connections.push(PooledConnection::new(connection));
141 true
142 }
143
144 fn connection_count(&self) -> usize {
145 self.connections.len()
146 }
147}
148
/// QUIC transport with per-peer connection pooling.
pub struct QuicTransport {
    // Dual-role quinn endpoint: accepts incoming and dials outgoing connections.
    endpoint: Endpoint,
    // Per-peer pools of reusable connections, keyed by remote address.
    pools: Arc<RwLock<HashMap<SocketAddr, PeerPool>>>,
    // Transport-wide tuning knobs (timeouts, pool sizes, message limits).
    config: QuicConfig,
    // Client TLS/transport configuration reused for every outgoing dial.
    client_config: ClientConfig,
}
160
impl QuicTransport {
    /// Creates a transport bound to `config.bind_addr` that can both accept
    /// and dial QUIC connections.
    ///
    /// # Errors
    /// Returns `Error::Internal` if certificate generation, TLS/transport
    /// configuration, or binding the UDP socket fails.
    pub async fn new(config: QuicConfig) -> Result<Self> {
        // Server identity: a throwaway self-signed certificate for "localhost".
        let (cert, key) = Self::generate_self_signed_cert()?;

        let server_transport = Self::create_transport_config(&config);
        let mut server_config = ServerConfig::with_single_cert(vec![cert.clone()], key.clone_key())
            .map_err(|e| Error::Internal(format!("Failed to create server config: {}", e)))?;
        server_config.transport_config(Arc::new(server_transport));

        // Client side deliberately skips certificate verification because
        // peers present self-signed certs; see `SkipServerVerification`.
        let client_transport = Self::create_transport_config(&config);
        let client_crypto = rustls::ClientConfig::builder()
            .dangerous()
            .with_custom_certificate_verifier(Arc::new(SkipServerVerification))
            .with_no_client_auth();
        let mut client_config = ClientConfig::new(Arc::new(
            quinn::crypto::rustls::QuicClientConfig::try_from(client_crypto).map_err(|e| {
                Error::Internal(format!("Failed to create QUIC client config: {}", e))
            })?,
        ));
        client_config.transport_config(Arc::new(client_transport));

        // One endpoint serves both roles: it listens with `server_config` and
        // dials with `client_config` (supplied per-connect in `connect`).
        let endpoint = Endpoint::server(server_config, config.bind_addr)
            .map_err(|e| Error::Internal(format!("Failed to create QUIC endpoint: {}", e)))?;

        Ok(Self {
            endpoint,
            pools: Arc::new(RwLock::new(HashMap::new())),
            config,
            client_config,
        })
    }

    /// Builds quinn transport settings (timeouts, stream limits, MTU) from
    /// `config`.
    ///
    /// NOTE(review): `config.enable_0rtt`, `config.initial_window`, and
    /// `config.max_window` are never applied here — confirm whether they
    /// should be mapped onto quinn's window/0-RTT settings or removed.
    fn create_transport_config(config: &QuicConfig) -> TransportConfig {
        let mut transport = TransportConfig::default();
        // Duration -> IdleTimeout conversion can fail for very large values;
        // in that case this silently falls back to `IdleTimeout::default()` —
        // TODO confirm that fallback is intended.
        transport.max_idle_timeout(Some(config.idle_timeout.try_into().unwrap_or_default()));
        transport.max_concurrent_bidi_streams(config.max_streams.into());
        transport.max_concurrent_uni_streams(config.max_streams.into());
        // Conservative initial MTU; quinn can discover a larger path MTU later.
        transport.initial_mtu(1200);
        transport
    }

    /// Generates a throwaway self-signed certificate/key pair for
    /// "localhost", returned in DER form for rustls.
    fn generate_self_signed_cert() -> Result<(
        rustls::pki_types::CertificateDer<'static>,
        rustls::pki_types::PrivateKeyDer<'static>,
    )> {
        let rcgen_cert = rcgen::generate_simple_self_signed(vec!["localhost".to_string()])
            .map_err(|e| Error::Internal(format!("Failed to generate certificate: {}", e)))?;

        let cert_der = rustls::pki_types::CertificateDer::from(rcgen_cert.cert.der().to_vec());
        let key_der =
            rustls::pki_types::PrivateKeyDer::try_from(rcgen_cert.signing_key.serialize_der())
                .map_err(|e| Error::Internal(format!("Failed to serialize key: {}", e)))?;

        Ok((cert_der, key_der))
    }

    /// The actual bound local address (useful when binding to port 0).
    pub fn local_addr(&self) -> Result<SocketAddr> {
        self.endpoint
            .local_addr()
            .map_err(|e| Error::Internal(format!("Failed to get local address: {}", e)))
    }

    /// Returns a connection to `addr`, reusing a pooled one when available.
    ///
    /// NOTE(review): `active_streams` is never incremented when a pooled
    /// connection is handed out, so `min_by_key(active_streams)` currently
    /// cannot balance load — confirm whether stream accounting is planned.
    pub async fn connect(&self, addr: SocketAddr) -> Result<Connection> {
        // Fast path: hand out an existing healthy pooled connection.
        {
            let mut pools = self.pools.write().await;
            if let Some(pool) = pools.get_mut(&addr) {
                if let Some(conn) = pool.get() {
                    conn.touch();
                    return Ok(conn.connection.clone());
                }
            }
        }

        // Slow path: dial. The pool lock is NOT held across the handshake,
        // so two concurrent callers may both dial; both connections are then
        // pooled (subject to the pool's size cap).
        let connection = self
            .endpoint
            .connect_with(self.client_config.clone(), addr, "localhost")
            .map_err(|e| Error::Internal(format!("Failed to initiate connection: {}", e)))?
            .await
            .map_err(|e| Error::Internal(format!("Failed to connect: {}", e)))?;

        {
            let mut pools = self.pools.write().await;
            let pool = pools.entry(addr).or_insert_with(|| {
                PeerPool::new(self.config.pool_size, self.config.pool_idle_timeout)
            });
            // If the pool is full of busy connections this add is a no-op;
            // the caller still gets the fresh (unpooled) connection.
            pool.add(connection.clone());
        }

        Ok(connection)
    }

    /// Accepts the next incoming connection.
    ///
    /// Returns `Ok(None)` once the endpoint has been closed.
    /// Note: accepted connections are not added to the pools.
    pub async fn accept(&self) -> Result<Option<Connection>> {
        if let Some(incoming) = self.endpoint.accept().await {
            let connection = incoming
                .await
                .map_err(|e| Error::Internal(format!("Failed to accept connection: {}", e)))?;
            Ok(Some(connection))
        } else {
            Ok(None)
        }
    }

    /// Opens a new bidirectional stream on `connection`.
    pub async fn open_stream(&self, connection: &Connection) -> Result<(SendStream, RecvStream)> {
        connection
            .open_bi()
            .await
            .map_err(|e| Error::Internal(format!("Failed to open stream: {}", e)))
    }

    /// Writes `data` in full and closes the send side of the stream.
    pub async fn send(&self, stream: &mut SendStream, data: &[u8]) -> Result<()> {
        stream
            .write_all(data)
            .await
            .map_err(|e| Error::Internal(format!("Failed to send data: {}", e)))?;
        stream
            .finish()
            .map_err(|e| Error::Internal(format!("Failed to finish stream: {}", e)))?;
        Ok(())
    }

    /// Reads the stream to completion, bounded by `config.max_message_size`.
    pub async fn receive(&self, stream: &mut RecvStream) -> Result<Vec<u8>> {
        let data = stream
            .read_to_end(self.config.max_message_size)
            .await
            .map_err(|e| Error::Internal(format!("Failed to receive data: {}", e)))?;
        Ok(data)
    }

    /// Like [`send`], but takes ownership of a `Bytes` buffer so the caller
    /// avoids an extra copy on their side.
    pub async fn send_zero_copy(&self, stream: &mut SendStream, data: Bytes) -> Result<()> {
        stream
            .write_all(&data)
            .await
            .map_err(|e| Error::Internal(format!("Failed to send data: {}", e)))?;
        stream
            .finish()
            .map_err(|e| Error::Internal(format!("Failed to finish stream: {}", e)))?;
        Ok(())
    }

    /// Like [`receive`], but returns the payload as `Bytes`.
    pub async fn receive_zero_copy(&self, stream: &mut RecvStream) -> Result<Bytes> {
        let data = stream
            .read_to_end(self.config.max_message_size)
            .await
            .map_err(|e| Error::Internal(format!("Failed to receive data: {}", e)))?;
        Ok(Bytes::from(data))
    }

    /// Relays all bytes from `recv_stream` to `send_stream`, returning the
    /// number of bytes forwarded.
    pub async fn forward_block(
        &self,
        recv_stream: &mut RecvStream,
        send_stream: &mut SendStream,
    ) -> Result<usize> {
        let mut total_bytes = 0;
        // 16 KiB copy buffer; `read` yields Ok(None) once the peer finishes.
        let mut buffer = vec![0u8; 16384];
        loop {
            let n = match recv_stream.read(&mut buffer).await {
                Ok(Some(n)) => n,
                Ok(None) => break,
                Err(e) => return Err(Error::Internal(format!("Failed to read: {}", e))),
            };

            send_stream
                .write_all(&buffer[..n])
                .await
                .map_err(|e| Error::Internal(format!("Failed to write: {}", e)))?;

            total_bytes += n;
        }

        // Signal end-of-data to the downstream peer.
        send_stream
            .finish()
            .map_err(|e| Error::Internal(format!("Failed to finish stream: {}", e)))?;

        Ok(total_bytes)
    }

    /// Convenience: connect (or reuse) to `addr`, open a stream, and send
    /// `data` as one finished message. The receive half is dropped unused.
    pub async fn send_to(&self, addr: SocketAddr, data: &[u8]) -> Result<()> {
        let connection = self.connect(addr).await?;
        let (mut send, _recv) = self.open_stream(&connection).await?;
        self.send(&mut send, data).await
    }

    /// Snapshot of pool occupancy across all peers.
    pub async fn pool_stats(&self) -> QuicPoolStats {
        let pools = self.pools.read().await;
        let total_connections: usize = pools.values().map(|p| p.connection_count()).sum();
        let peer_count = pools.len();

        QuicPoolStats {
            peer_count,
            total_connections,
        }
    }

    /// Evicts closed and idle connections, then drops peers whose pools
    /// became empty. Intended to be called periodically.
    pub async fn cleanup_idle(&self) {
        let mut pools = self.pools.write().await;
        for pool in pools.values_mut() {
            pool.connections
                .retain(|c| c.is_healthy() && !c.is_idle(pool.idle_timeout));
        }
        pools.retain(|_, p| !p.connections.is_empty());
    }

    /// Closes the endpoint with application code 0 and reason "shutdown";
    /// pending operations on its connections will fail.
    pub fn close(&self) {
        self.endpoint.close(0u32.into(), b"shutdown");
    }
}
394
/// Snapshot of connection-pool occupancy, as returned by `pool_stats`.
#[derive(Debug, Clone)]
pub struct QuicPoolStats {
    /// Number of distinct peers with at least one pooled connection.
    pub peer_count: usize,
    /// Total pooled connections across all peers.
    pub total_connections: usize,
}
403
/// TLS certificate verifier that accepts ANY server certificate.
///
/// SECURITY: this disables certificate validation entirely. The transport's
/// peers present self-signed certificates (see `generate_self_signed_cert`),
/// so the channel is encrypted but NOT authenticated — an active
/// man-in-the-middle presenting any certificate would be accepted.
/// Do not reuse this verifier outside this transport.
#[derive(Debug)]
struct SkipServerVerification;

impl rustls::client::danger::ServerCertVerifier for SkipServerVerification {
    // Accepts every certificate chain without inspecting it.
    fn verify_server_cert(
        &self,
        _end_entity: &rustls::pki_types::CertificateDer<'_>,
        _intermediates: &[rustls::pki_types::CertificateDer<'_>],
        _server_name: &rustls::pki_types::ServerName<'_>,
        _ocsp_response: &[u8],
        _now: rustls::pki_types::UnixTime,
    ) -> std::result::Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
        Ok(rustls::client::danger::ServerCertVerified::assertion())
    }

    // Accepts every TLS 1.2 handshake signature without checking it.
    fn verify_tls12_signature(
        &self,
        _message: &[u8],
        _cert: &rustls::pki_types::CertificateDer<'_>,
        _dss: &rustls::DigitallySignedStruct,
    ) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
    }

    // Accepts every TLS 1.3 handshake signature without checking it.
    fn verify_tls13_signature(
        &self,
        _message: &[u8],
        _cert: &rustls::pki_types::CertificateDer<'_>,
        _dss: &rustls::DigitallySignedStruct,
    ) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
    }

    // Advertises a broad set of schemes so any peer certificate negotiates.
    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
        vec![
            rustls::SignatureScheme::RSA_PKCS1_SHA256,
            rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
            rustls::SignatureScheme::RSA_PKCS1_SHA384,
            rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
            rustls::SignatureScheme::RSA_PKCS1_SHA512,
            rustls::SignatureScheme::ECDSA_NISTP521_SHA512,
            rustls::SignatureScheme::RSA_PSS_SHA256,
            rustls::SignatureScheme::RSA_PSS_SHA384,
            rustls::SignatureScheme::RSA_PSS_SHA512,
            rustls::SignatureScheme::ED25519,
        ]
    }
}
453
454pub struct BlockStream {
456 send: SendStream,
457 recv: RecvStream,
458}
459
460impl BlockStream {
461 pub fn new(send: SendStream, recv: RecvStream) -> Self {
463 Self { send, recv }
464 }
465
466 pub async fn send_request(&mut self, data: &[u8]) -> Result<()> {
468 self.send
469 .write_all(data)
470 .await
471 .map_err(|e| Error::Internal(format!("Failed to send request: {}", e)))?;
472 self.send
473 .finish()
474 .map_err(|e| Error::Internal(format!("Failed to finish stream: {}", e)))?;
475 Ok(())
476 }
477
478 pub async fn receive_response(&mut self, max_size: usize) -> Result<Vec<u8>> {
480 self.recv
481 .read_to_end(max_size)
482 .await
483 .map_err(|e| Error::Internal(format!("Failed to receive response: {}", e)))
484 }
485
486 pub async fn send_request_zero_copy(&mut self, data: Bytes) -> Result<()> {
488 self.send
489 .write_all(&data)
490 .await
491 .map_err(|e| Error::Internal(format!("Failed to send request: {}", e)))?;
492 self.send
493 .finish()
494 .map_err(|e| Error::Internal(format!("Failed to finish stream: {}", e)))?;
495 Ok(())
496 }
497
498 pub async fn receive_response_zero_copy(&mut self, max_size: usize) -> Result<Bytes> {
500 let data = self
501 .recv
502 .read_to_end(max_size)
503 .await
504 .map_err(|e| Error::Internal(format!("Failed to receive response: {}", e)))?;
505 Ok(Bytes::from(data))
506 }
507}
508
/// Issues multiple request/response exchanges concurrently over a single
/// QUIC connection, each on its own bidirectional stream.
pub struct ParallelRequester {
    // Connection on which every stream is opened.
    connection: Connection,
    // Concurrency cap applied in `execute_parallel`.
    max_concurrent: usize,
    #[allow(dead_code)]
    // Currently unused; presumably intended as a per-response read limit —
    // TODO confirm.
    max_message_size: usize,
}

impl ParallelRequester {
    /// Wraps an existing connection with the given limits.
    pub fn new(connection: Connection, max_concurrent: usize, max_message_size: usize) -> Self {
        Self {
            connection,
            max_concurrent,
            max_message_size,
        }
    }

    /// Opens a fresh bidirectional stream wrapped as a [`BlockStream`].
    pub async fn open_stream(&self) -> Result<BlockStream> {
        let (send, recv) = self
            .connection
            .open_bi()
            .await
            .map_err(|e| Error::Internal(format!("Failed to open stream: {}", e)))?;
        Ok(BlockStream::new(send, recv))
    }

    /// Runs each request closure on its own stream, with at most
    /// `max_concurrent` in flight at a time.
    ///
    /// Note: `buffer_unordered` yields results in COMPLETION order, so the
    /// returned `Vec` is not necessarily in submission order.
    pub async fn execute_parallel<F, Fut, T>(&self, requests: Vec<F>) -> Vec<Result<T>>
    where
        F: FnOnce(BlockStream) -> Fut,
        Fut: std::future::Future<Output = Result<T>> + Send,
        T: Send,
    {
        use futures::stream::{self, StreamExt};

        let max_concurrent = self.max_concurrent;

        stream::iter(requests)
            .map(|request| async move {
                // A stream-open failure surfaces as that request's Err.
                let stream = self.open_stream().await?;
                request(stream).await
            })
            .buffer_unordered(max_concurrent)
            .collect()
            .await
    }

    /// The configured concurrency limit.
    pub fn max_concurrent(&self) -> usize {
        self.max_concurrent
    }
}
564
/// Adjusts a request batch size to chase a target throughput (items/sec),
/// using a sliding window of observed batch completion times.
pub struct AdaptiveBatchTuner {
    /// Batch size currently in effect.
    current_batch_size: usize,
    /// Lower clamp for adjustments.
    min_batch_size: usize,
    /// Upper clamp for adjustments.
    max_batch_size: usize,
    /// Recent batch completion times, in milliseconds (sliding window).
    completion_times: Vec<u64>,
    /// Maximum number of samples retained in the window.
    window_size: usize,
    /// Desired throughput, in items per second.
    target_throughput: f64,
    /// When the batch size was last adjusted.
    last_adjustment: Instant,
    /// Minimum time between adjustments.
    adjustment_interval: Duration,
}

impl AdaptiveBatchTuner {
    /// Creates a tuner starting at `initial_batch_size`, clamped thereafter
    /// to `[min_batch_size, max_batch_size]`, targeting `target_throughput`
    /// items per second.
    pub fn new(
        initial_batch_size: usize,
        min_batch_size: usize,
        max_batch_size: usize,
        target_throughput: f64,
    ) -> Self {
        Self {
            current_batch_size: initial_batch_size,
            min_batch_size,
            max_batch_size,
            completion_times: Vec::new(),
            window_size: 10,
            target_throughput,
            last_adjustment: Instant::now(),
            adjustment_interval: Duration::from_secs(1),
        }
    }

    /// Records one batch completion time (ms), keeping only the most recent
    /// `window_size` samples.
    pub fn record_completion(&mut self, duration_ms: u64) {
        self.completion_times.push(duration_ms);
        if self.completion_times.len() > self.window_size {
            // Window is tiny (10), so the O(n) front-removal is negligible.
            self.completion_times.remove(0);
        }
    }

    /// The batch size currently in effect.
    pub fn current_batch_size(&self) -> usize {
        self.current_batch_size
    }

    /// Re-evaluates the batch size from the sample window and returns the
    /// (possibly updated) size. No-op when called within
    /// `adjustment_interval` of the last adjustment, or with fewer than
    /// three samples.
    pub fn adjust_batch_size(&mut self) -> usize {
        // Rate-limit adjustments to avoid oscillation.
        if self.last_adjustment.elapsed() < self.adjustment_interval {
            return self.current_batch_size;
        }

        // Need a few samples before the average is meaningful.
        if self.completion_times.len() < 3 {
            return self.current_batch_size;
        }

        let avg_time =
            self.completion_times.iter().sum::<u64>() as f64 / self.completion_times.len() as f64;

        // Items per second. If avg_time is 0 this becomes +inf, which lands
        // in the "too fast -> shrink" branch rather than misbehaving.
        let current_throughput = (self.current_batch_size as f64 / avg_time) * 1000.0;

        let new_batch_size = if current_throughput < self.target_throughput * 0.8 {
            // Grow by 20%, but always by at least 1: the truncating cast
            // would otherwise stall small sizes (e.g. 4 * 1.2 = 4.8 -> 4).
            ((self.current_batch_size as f64 * 1.2) as usize).max(self.current_batch_size + 1)
        } else if current_throughput > self.target_throughput * 1.2 {
            // Shrink by 20%; the clamp below keeps us >= min_batch_size.
            (self.current_batch_size as f64 * 0.8) as usize
        } else {
            // Within +/-20% of target: leave the size alone.
            self.current_batch_size
        };

        self.current_batch_size = new_batch_size.clamp(self.min_batch_size, self.max_batch_size);
        self.last_adjustment = Instant::now();
        // Start a fresh window so the next decision uses post-change samples.
        self.completion_times.clear();

        self.current_batch_size
    }

    /// Discards the sample window and restarts the adjustment timer.
    pub fn reset(&mut self) {
        self.completion_times.clear();
        self.last_adjustment = Instant::now();
    }
}

impl Default for AdaptiveBatchTuner {
    /// 32-item batches, bounded to [8, 128], targeting 100 items/sec.
    fn default() -> Self {
        Self::new(32, 8, 128, 100.0)
    }
}
671
/// Tuning knobs for [`SequentialPipeline`]'s speculative prefetching.
#[derive(Debug, Clone)]
pub struct PipelineConfig {
    /// How many indices ahead of the current request to start speculatively.
    pub prefetch_depth: usize,
    /// Upper bound on requests in flight at once.
    pub max_pipeline_size: usize,
    /// Whether speculative prefetching is enabled at all.
    pub enable_speculation: bool,
}

impl Default for PipelineConfig {
    fn default() -> Self {
        PipelineConfig {
            prefetch_depth: 4,
            max_pipeline_size: 16,
            enable_speculation: true,
        }
    }
}
692
/// Pipelines sequential request/response exchanges over one connection,
/// optionally starting speculative requests for upcoming indices.
pub struct SequentialPipeline {
    // Connection on which every pipelined stream is opened.
    connection: Connection,
    // Prefetch depth / speculation settings.
    config: PipelineConfig,
    // Bound passed to `read_to_end` for each response.
    max_message_size: usize,
    // Outstanding request tasks, keyed by request index.
    in_flight: Arc<RwLock<HashMap<u64, tokio::task::JoinHandle<Result<Bytes>>>>>,
    // Monotonically increasing index of the next request to claim.
    next_index: Arc<RwLock<u64>>,
}
708
impl SequentialPipeline {
    /// Creates a pipeline over `connection` with the given settings.
    pub fn new(connection: Connection, config: PipelineConfig, max_message_size: usize) -> Self {
        Self {
            connection,
            config,
            max_message_size,
            in_flight: Arc::new(RwLock::new(HashMap::new())),
            next_index: Arc::new(RwLock::new(0)),
        }
    }

    /// Spawns a background task that performs one full request/response
    /// exchange on a fresh stream, and registers its handle under `index`.
    async fn start_request(&self, index: u64, request_data: Bytes) -> Result<()> {
        let connection = self.connection.clone();
        let max_size = self.max_message_size;

        let handle = tokio::spawn(async move {
            let (mut send, mut recv) = connection
                .open_bi()
                .await
                .map_err(|e| Error::Internal(format!("Failed to open stream: {}", e)))?;

            send.write_all(&request_data)
                .await
                .map_err(|e| Error::Internal(format!("Failed to send: {}", e)))?;
            send.finish()
                .map_err(|e| Error::Internal(format!("Failed to finish: {}", e)))?;

            let data = recv
                .read_to_end(max_size)
                .await
                .map_err(|e| Error::Internal(format!("Failed to receive: {}", e)))?;

            Ok(Bytes::from(data))
        });

        let mut in_flight = self.in_flight.write().await;
        in_flight.insert(index, handle);

        Ok(())
    }

    /// Fetches the response for the next sequential index, first kicking off
    /// speculative requests for upcoming indices when speculation is enabled.
    ///
    /// NOTE(review): speculative prefetches send `request_data.clone()` for
    /// FUTURE indices — i.e. the same payload as the current request.
    /// Confirm the protocol derives the block from arrival/stream order
    /// rather than the payload, otherwise prefetched results are wrong.
    pub async fn fetch_next(&self, request_data: Bytes) -> Result<Bytes> {
        // Atomically claim this call's index.
        let current_index = {
            let mut next = self.next_index.write().await;
            let current = *next;
            *next += 1;
            current
        };

        if self.config.enable_speculation {
            // Start up to `prefetch_depth` speculative requests for the
            // indices after this one, skipping any already in flight.
            for i in 1..=self.config.prefetch_depth {
                let prefetch_index = current_index + i as u64;

                let in_flight = self.in_flight.read().await;
                if !in_flight.contains_key(&prefetch_index) {
                    // Release the read lock before start_request takes the
                    // write lock on the same map.
                    drop(in_flight);

                    // Best-effort: a failed prefetch is simply retried
                    // lazily when its index becomes current.
                    let _ = self
                        .start_request(prefetch_index, request_data.clone())
                        .await;
                }
            }
        }

        // Take (or lazily start) the task for the current index.
        let handle = {
            let mut in_flight = self.in_flight.write().await;

            if !in_flight.contains_key(&current_index) {
                // Not prefetched yet: drop the lock, start the request
                // (which re-acquires it), then remove the fresh handle.
                drop(in_flight);
                self.start_request(current_index, request_data).await?;
                let mut in_flight = self.in_flight.write().await;
                in_flight.remove(&current_index)
            } else {
                in_flight.remove(&current_index)
            }
        };

        if let Some(handle) = handle {
            // Outer `?` covers task panic/cancellation; the inner Result is
            // the request's own outcome.
            handle
                .await
                .map_err(|e| Error::Internal(format!("Task failed: {}", e)))?
        } else {
            Err(Error::Internal("Request handle not found".to_string()))
        }
    }

    /// Aborts and discards every in-flight request task.
    pub async fn clear(&self) {
        let mut in_flight = self.in_flight.write().await;
        for (_, handle) in in_flight.drain() {
            handle.abort();
        }
    }

    /// Number of request tasks currently outstanding.
    pub async fn in_flight_count(&self) -> usize {
        self.in_flight.read().await.len()
    }
}
818
#[cfg(test)]
mod tests {
    use super::*;

    // Default QuicConfig carries the documented tuning values.
    #[test]
    fn test_quic_config_defaults() {
        let cfg = QuicConfig::default();
        assert_eq!(cfg.max_streams, 256);
        assert!(cfg.enable_0rtt);
        assert_eq!(cfg.pool_size, 4);
    }

    // A freshly created pool holds no connections.
    #[test]
    fn test_peer_pool() {
        let empty_pool = PeerPool::new(4, Duration::from_secs(60));
        assert_eq!(empty_pool.connection_count(), 0);
    }
}