use std::net::SocketAddr;
use std::sync::Arc;
use quinn::{ClientConfig, Endpoint, RecvStream, SendStream, ServerConfig};
use crate::codec::Codec;
use crate::error::{KnafehError, RpcStatusCode};
use crate::rpc::message::{Metadata, RpcRequest, RpcResponse, RpcStatus};
use crate::rpc::middleware::MiddlewareStack;
use crate::rpc::router::MethodRouter;
use crate::transport::quic_wire;
/// Client half of the QUIC RPC transport.
///
/// Holds one established QUIC connection; every [`QuicClient::call`] opens a new
/// bidirectional stream on it for a single request/response exchange.
pub struct QuicClient {
    connection: quinn::Connection,
    codec: Arc<dyn Codec>,
    middleware: Arc<MiddlewareStack>,
}

impl QuicClient {
    /// Connects to `addr` **without verifying the server's certificate**.
    ///
    /// Every certificate the server presents is accepted (the custom `SkipVerification`
    /// verifier is installed), so the connection is encrypted but the server is not
    /// authenticated. Use only for tests and local development. The TLS server name sent
    /// during the handshake is always `"localhost"`.
    ///
    /// The local UDP socket is bound to the unspecified address of the same family as
    /// `addr` (IPv4 or IPv6) on an ephemeral port.
    ///
    /// # Errors
    ///
    /// - [`KnafehError::Tls`] if the rustls configuration cannot be converted into a QUIC
    ///   client configuration.
    /// - [`KnafehError::Transport`] if the local endpoint cannot be created, the connect
    ///   parameters are rejected, or the handshake fails.
    pub async fn connect_insecure(
        addr: SocketAddr,
        codec: Arc<dyn Codec>,
        middleware: Arc<MiddlewareStack>,
    ) -> Result<Self, KnafehError> {
        let crypto = rustls::ClientConfig::builder()
            .dangerous()
            .with_custom_certificate_verifier(Arc::new(SkipVerification))
            .with_no_client_auth();
        let client_cfg = ClientConfig::new(Arc::new(
            quinn::crypto::rustls::QuicClientConfig::try_from(crypto)
                .map_err(|e| KnafehError::Tls(e.to_string()))?,
        ));

        // The local socket must be in the same address family as the peer: a socket bound
        // to 0.0.0.0 cannot send datagrams to an IPv6 address.
        let bind_addr = if addr.is_ipv6() {
            SocketAddr::from((std::net::Ipv6Addr::UNSPECIFIED, 0))
        } else {
            SocketAddr::from((std::net::Ipv4Addr::UNSPECIFIED, 0))
        };
        let mut endpoint =
            Endpoint::client(bind_addr).map_err(|e| KnafehError::Transport(e.to_string()))?;
        endpoint.set_default_client_config(client_cfg);

        let connection = endpoint
            .connect(addr, "localhost")
            .map_err(|e| KnafehError::Transport(e.to_string()))?
            .await
            .map_err(|e| KnafehError::Transport(e.to_string()))?;

        Ok(Self {
            connection,
            codec,
            middleware,
        })
    }

    /// Performs a single unary RPC on a new bidirectional stream.
    ///
    /// The request body is encoded with the codec *before* request middleware runs, so
    /// request middleware sees the encoded bytes. The response is read in full (at most
    /// `quic_wire::MAX_MESSAGE_SIZE` bytes) and response middleware runs. The body is then
    /// decoded with the codec only when the status is OK; error bodies are returned as
    /// received.
    ///
    /// A non-OK RPC status is reported through `Ok(response)`, not through `Err`.
    ///
    /// # Errors
    ///
    /// [`KnafehError::Transport`] if the stream cannot be opened, written, finished, or read.
    /// Errors from the codec, the middleware, and response-frame decoding are propagated.
    pub async fn call(&self, method: &str, body: Vec<u8>) -> Result<RpcResponse, KnafehError> {
        let encoded_body = self.codec.encode(&body)?;
        let mut request = RpcRequest {
            method: method.to_string(),
            metadata: Metadata::new(),
            body: encoded_body,
        };
        self.middleware.apply_request(&mut request).await?;

        let (mut send, mut recv) = self
            .connection
            .open_bi()
            .await
            .map_err(|e| KnafehError::Transport(format!("open_bi failed: {e}")))?;

        let req_bytes =
            quic_wire::encode_request(&request.method, &request.body, &request.metadata);
        send.write_all(&req_bytes)
            .await
            .map_err(|e| KnafehError::Transport(format!("write failed: {e}")))?;
        // Signal end-of-request so the server's read-to-end completes.
        send.finish()
            .map_err(|e| KnafehError::Transport(format!("finish failed: {e}")))?;

        let resp_bytes = recv
            .read_to_end(quic_wire::MAX_MESSAGE_SIZE)
            .await
            .map_err(|e| KnafehError::Transport(format!("read failed: {e}")))?;
        let (status_code, status_message, resp_body, metadata) =
            quic_wire::decode_response(&resp_bytes)?;

        let mut response = RpcResponse {
            status: RpcStatus {
                code: status_code,
                message: status_message,
            },
            metadata,
            body: resp_body,
        };
        self.middleware.apply_response(&mut response).await?;
        if response.status.is_ok() {
            response.body = self.codec.decode(&response.body)?;
        }
        Ok(response)
    }
}
pub struct QuicServer {
endpoint: Endpoint,
router: Arc<MethodRouter>,
codec: Arc<dyn Codec>,
middleware: Arc<MiddlewareStack>,
}
impl QuicServer {
pub fn bind(
addr: SocketAddr,
cert_pem: &str,
key_pem: &str,
router: MethodRouter,
codec: Arc<dyn Codec>,
middleware: MiddlewareStack,
) -> Result<Self, KnafehError> {
let certs = rustls_pemfile::certs(&mut cert_pem.as_bytes())
.collect::<Result<Vec<_>, _>>()
.map_err(|e| KnafehError::Tls(format!("cert parse error: {e}")))?;
let key = rustls_pemfile::private_key(&mut key_pem.as_bytes())
.map_err(|e| KnafehError::Tls(format!("key parse error: {e}")))?
.ok_or_else(|| KnafehError::Tls("no private key found".into()))?;
let server_crypto = rustls::ServerConfig::builder()
.with_no_client_auth()
.with_single_cert(certs, key)
.map_err(|e| KnafehError::Tls(e.to_string()))?;
let server_config = ServerConfig::with_crypto(Arc::new(
quinn::crypto::rustls::QuicServerConfig::try_from(server_crypto)
.map_err(|e| KnafehError::Tls(e.to_string()))?,
));
let endpoint = Endpoint::server(server_config, addr)
.map_err(|e| KnafehError::Transport(e.to_string()))?;
Ok(Self {
endpoint,
router: Arc::new(router),
codec,
middleware: Arc::new(middleware),
})
}
pub fn local_addr(&self) -> Result<SocketAddr, KnafehError> {
self.endpoint
.local_addr()
.map_err(|e| KnafehError::Transport(e.to_string()))
}
pub async fn serve(&self) -> Result<(), KnafehError> {
while let Some(incoming) = self.endpoint.accept().await {
let conn = incoming
.await
.map_err(|e| KnafehError::Transport(format!("accept failed: {e}")))?;
let router = Arc::clone(&self.router);
let codec = Arc::clone(&self.codec);
let middleware = Arc::clone(&self.middleware);
tokio::spawn(async move {
loop {
let stream = conn.accept_bi().await;
match stream {
Ok((send, recv)) => {
let router = Arc::clone(&router);
let codec = Arc::clone(&codec);
let middleware = Arc::clone(&middleware);
tokio::spawn(async move {
if let Err(e) =
handle_stream(send, recv, &router, codec.as_ref(), &middleware)
.await
{
tracing::warn!(error = %e, "stream handler error");
}
});
}
Err(quinn::ConnectionError::ApplicationClosed(_)) => break,
Err(e) => {
tracing::warn!(error = %e, "accept_bi error");
break;
}
}
}
});
}
Ok(())
}
}
async fn handle_stream(
mut send: SendStream,
mut recv: RecvStream,
router: &MethodRouter,
codec: &dyn Codec,
middleware: &MiddlewareStack,
) -> Result<(), KnafehError> {
let req_bytes = recv
.read_to_end(super::quic_wire::MAX_MESSAGE_SIZE)
.await
.map_err(|e| KnafehError::Transport(format!("read failed: {e}")))?;
let (method, raw_body, metadata) = quic_wire::decode_request(&req_bytes)?;
let body = codec.decode(&raw_body)?;
let mut request = RpcRequest {
method,
metadata,
body,
};
middleware.apply_request(&mut request).await?;
let mut response = match router.route_unary(request).await {
Ok(resp) => resp,
Err(KnafehError::Service { code, message }) => RpcResponse::error(code, message),
Err(e) => RpcResponse::error(RpcStatusCode::Internal, e.to_string()),
};
middleware.apply_response(&mut response).await?;
let encoded_body = codec.encode(&response.body)?;
let resp_bytes = quic_wire::encode_response(
response.status.code,
&response.status.message,
&encoded_body,
&response.metadata,
);
send.write_all(&resp_bytes)
.await
.map_err(|e| KnafehError::Transport(format!("write failed: {e}")))?;
send.finish()
.map_err(|e| KnafehError::Transport(format!("finish failed: {e}")))?;
Ok(())
}
#[derive(Debug)]
struct SkipVerification;
impl rustls::client::danger::ServerCertVerifier for SkipVerification {
fn verify_server_cert(
&self,
_: &rustls::pki_types::CertificateDer,
_: &[rustls::pki_types::CertificateDer],
_: &rustls::pki_types::ServerName,
_: &[u8],
_: rustls::pki_types::UnixTime,
) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
Ok(rustls::client::danger::ServerCertVerified::assertion())
}
fn verify_tls12_signature(
&self,
_: &[u8],
_: &rustls::pki_types::CertificateDer,
_: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
}
fn verify_tls13_signature(
&self,
_: &[u8],
_: &rustls::pki_types::CertificateDer,
_: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
}
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
vec![
rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
rustls::SignatureScheme::RSA_PSS_SHA256,
rustls::SignatureScheme::RSA_PSS_SHA384,
rustls::SignatureScheme::RSA_PSS_SHA512,
rustls::SignatureScheme::ED25519,
]
}
}