use std::io;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use osproxy_spi::HttpMethod;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::net::{TcpListener, TcpStream};
use tokio_rustls::TlsAcceptor;
use tokio_stream::wrappers::{ReceiverStream, TcpListenerStream};
use tokio_stream::StreamExt;
use tonic::transport::server::Connected;
use tonic::{Request, Response, Status};
use crate::classify::classify;
use crate::handler::IngressHandler;
use crate::request::IngressRequest;
use crate::tls::CryptoProvider;
// Protobuf/gRPC bindings for the `osproxy.v1` package, spliced in by `tonic::include_proto!`.
mod pb {
// The module body is produced by the proto code generator rather than written by hand,
// so lints are silenced wholesale here instead of being fixed in generated code.
#![allow(
clippy::pedantic,
clippy::nursery,
clippy::all,
missing_docs,
unreachable_pub,
dead_code
)]
tonic::include_proto!("osproxy.v1");
}
use pb::document_service_server::{DocumentService, DocumentServiceServer};
use pb::{IndexReply, IndexRequest};
/// Key of the gRPC metadata entry that `bearer_header` reads, and the header name under
/// which that value is forwarded to the handler.
const AUTHORIZATION: &str = "authorization";
/// gRPC front end for the ingress: its `DocumentService` implementation turns each RPC
/// into an `IngressRequest` and delegates it to the wrapped handler.
struct GrpcIngress<H> {
/// Handler that every RPC is forwarded to; held in an `Arc` so the caller can keep
/// its own reference to the same instance.
handler: Arc<H>,
}
#[tonic::async_trait]
impl<H: IngressHandler> DocumentService for GrpcIngress<H> {
async fn index(&self, request: Request<IndexRequest>) -> Result<Response<IndexReply>, Status> {
let headers = bearer_header(&request);
let conn = request.extensions().get::<GrpcConnInfo>();
let client_cert_subject = conn.and_then(|i| i.client_cert_subject.clone());
let secure = conn.is_some_and(|i| i.secure);
let msg = request.into_inner();
let (method, path) = if msg.id.is_empty() {
(HttpMethod::Post, format!("/{}/_doc", msg.index))
} else {
(HttpMethod::Put, format!("/{}/_doc/{}", msg.index, msg.id))
};
let c = classify(method, &path);
let ingress = IngressRequest {
method,
protocol: osproxy_spi::Protocol::Grpc,
path,
endpoint: c.endpoint,
logical_index: c.logical_index,
doc_id: c.doc_id,
headers,
body: msg.document,
query: None,
client_cert_subject,
secure,
};
let resp = self.handler.handle(ingress).await;
let request_id = resp
.headers
.iter()
.find(|(k, _)| k.eq_ignore_ascii_case("x-request-id"))
.map(|(_, v)| v.clone())
.unwrap_or_default();
Ok(Response::new(IndexReply {
status: u32::from(resp.status),
body: resp.body,
request_id,
}))
}
}
/// Extracts the `authorization` metadata entry of `request` as a header list.
///
/// Returns a single `("authorization", value)` pair when the entry is present and its
/// value converts via `to_str`; returns an empty list otherwise.
fn bearer_header<T>(request: &Request<T>) -> Vec<(String, String)> {
    match request
        .metadata()
        .get(AUTHORIZATION)
        .map(|value| value.to_str())
    {
        Some(Ok(credentials)) => vec![(AUTHORIZATION.to_owned(), credentials.to_owned())],
        // Entry absent, or present but not convertible to a string: forward nothing.
        _ => Vec::new(),
    }
}
/// Serves the document gRPC service over plaintext TCP on `listener`.
///
/// Every RPC is delegated to `handler`. This path never builds a `GrpcConnInfo`, so
/// `index` is expected to report such requests as not secure and without a client
/// certificate subject.
///
/// # Errors
///
/// Returns the `tonic::transport::Error` produced by the underlying tonic server.
pub async fn serve_grpc<H: IngressHandler>(
    listener: TcpListener,
    handler: Arc<H>,
) -> Result<(), tonic::transport::Error> {
    let incoming = TcpListenerStream::new(listener);
    let ingress = GrpcIngress { handler };

    let mut server = tonic::transport::Server::builder();
    server
        .add_service(DocumentServiceServer::new(ingress))
        .serve_with_incoming(incoming)
        .await
}
/// Per-connection facts that `TlsConn` reports through tonic's `Connected` trait and
/// that `index` looks up in the request extensions.
#[derive(Clone, Default)]
struct GrpcConnInfo {
/// Result of `crate::tls::client_subject_from_tls` for the connection's TLS session
/// (presumably `None` when the client presented no certificate — confirm against that helper).
client_cert_subject: Option<String>,
/// Set to `true` by `serve_grpc_tls` for every connection it accepts; the derived
/// `Default` leaves it `false`.
secure: bool,
}
/// A server-side TLS stream bundled with the connection info captured right after the
/// handshake, so the pair can be handed to tonic as a single I/O object.
struct TlsConn {
/// The established TLS session; all reads and writes are delegated to it.
inner: tokio_rustls::server::TlsStream<TcpStream>,
/// Connection info returned from `Connected::connect_info` for this stream.
info: GrpcConnInfo,
}
impl Connected for TlsConn {
    type ConnectInfo = GrpcConnInfo;

    /// Returns a copy of the info captured when the connection was accepted.
    fn connect_info(&self) -> Self::ConnectInfo {
        Clone::clone(&self.info)
    }
}
impl AsyncRead for TlsConn {
    /// Reads decrypted bytes by delegating to the wrapped TLS stream.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        // `TlsConn` is `Unpin`, so the pin can be unwrapped and re-applied to the field.
        let this = self.get_mut();
        Pin::new(&mut this.inner).poll_read(cx, buf)
    }
}
/// Write half of the wrapper: every operation is delegated to the inner TLS stream.
///
/// The vectored-write methods are forwarded as well. Without that, the trait's default
/// implementations would apply to the wrapper: it would always report that vectored
/// writes are unsupported and would submit only one buffer per call, regardless of what
/// the inner stream can do.
impl AsyncWrite for TlsConn {
    /// Writes `buf` through the TLS session.
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        Pin::new(&mut self.inner).poll_write(cx, buf)
    }

    /// Flushes the TLS session.
    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Pin::new(&mut self.inner).poll_flush(cx)
    }

    /// Shuts down the write side of the TLS session.
    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Pin::new(&mut self.inner).poll_shutdown(cx)
    }

    /// Writes from several buffers at once, using whatever strategy the inner stream
    /// implements.
    fn poll_write_vectored(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[io::IoSlice<'_>],
    ) -> Poll<io::Result<usize>> {
        Pin::new(&mut self.inner).poll_write_vectored(cx, bufs)
    }

    /// Reports the inner stream's vectored-write capability instead of the trait
    /// default of `false`.
    fn is_write_vectored(&self) -> bool {
        self.inner.is_write_vectored()
    }
}
/// Pause applied after a listener-level `accept` failure (for example file-descriptor
/// exhaustion) so the accept loop does not spin while the condition persists.
const ACCEPT_ERROR_BACKOFF: std::time::Duration = std::time::Duration::from_millis(100);

/// Aborts the wrapped task when dropped, tying a background task's lifetime to a scope.
struct AbortOnDrop(tokio::task::JoinHandle<()>);

impl Drop for AbortOnDrop {
    fn drop(&mut self) {
        self.0.abort();
    }
}

/// Returns `true` for `accept` errors that concern only the connection being accepted,
/// after which the listener can be polled again immediately.
fn is_per_connection_error(err: &io::Error) -> bool {
    matches!(
        err.kind(),
        io::ErrorKind::ConnectionRefused
            | io::ErrorKind::ConnectionAborted
            | io::ErrorKind::ConnectionReset
    )
}

/// Serves the document gRPC service over TLS on `listener`.
///
/// A background task accepts TCP connections and performs each TLS handshake in its own
/// task, using the server configuration from `provider`. Connections that complete the
/// handshake are wrapped in a [`TlsConn`] carrying `secure = true` and the client
/// certificate subject, then passed to the tonic server through a bounded channel.
/// Connections whose handshake fails are dropped silently.
///
/// Failures of `accept` do not stop the server: per-connection errors are skipped and
/// any other error is retried after [`ACCEPT_ERROR_BACKOFF`]. The accept task is aborted
/// when this future completes or is dropped, which releases the listener.
///
/// # Errors
///
/// Returns the `tonic::transport::Error` produced by the underlying tonic server.
pub async fn serve_grpc_tls<H, P>(
    listener: TcpListener,
    provider: Arc<P>,
    handler: Arc<H>,
) -> Result<(), tonic::transport::Error>
where
    H: IngressHandler,
    P: CryptoProvider,
{
    let acceptor = TlsAcceptor::from(provider.server_config());
    // Bounded hand-off: handshake tasks wait in `send` when the server falls behind.
    let (tx, rx) = tokio::sync::mpsc::channel::<TlsConn>(32);

    // Held until the end of this function so the accept loop cannot outlive the server.
    let _accept_loop = AbortOnDrop(tokio::spawn(async move {
        loop {
            let tcp = match listener.accept().await {
                Ok((tcp, _peer)) => tcp,
                Err(err) => {
                    // Keep serving: one failed accept must not take the listener down.
                    if !is_per_connection_error(&err) {
                        tokio::time::sleep(ACCEPT_ERROR_BACKOFF).await;
                    }
                    continue;
                }
            };
            if tx.is_closed() {
                // The server side of the channel is gone; nothing can use this connection.
                break;
            }
            let acceptor = acceptor.clone();
            let tx = tx.clone();
            // Handshake off the accept loop so a slow client cannot delay other connections.
            // NOTE(review): the handshake has no deadline; consider adding a timeout.
            tokio::spawn(async move {
                if let Ok(tls) = acceptor.accept(tcp).await {
                    let info = GrpcConnInfo {
                        client_cert_subject: crate::tls::client_subject_from_tls(&tls),
                        secure: true,
                    };
                    // A send error only means the server has shut down; drop the connection.
                    let _ = tx.send(TlsConn { inner: tls, info }).await;
                }
            });
        }
    }));

    let incoming = ReceiverStream::new(rx).map(Ok::<_, io::Error>);
    let service = DocumentServiceServer::new(GrpcIngress { handler });
    tonic::transport::Server::builder()
        .add_service(service)
        .serve_with_incoming(incoming)
        .await
}
#[cfg(test)]
mod tests {
    use std::sync::Mutex;

    use osproxy_core::EndpointKind;

    use super::pb::document_service_client::DocumentServiceClient;
    use super::*;
    use crate::request::IngressResponse;

    /// Handler double that remembers the last request it saw and answers with a fixed
    /// `201` JSON body plus an `x-request-id` header.
    #[derive(Default)]
    struct RecordingHandler {
        seen: Mutex<Option<IngressRequest>>,
    }

    impl IngressHandler for RecordingHandler {
        async fn handle(&self, req: IngressRequest) -> IngressResponse {
            *self.seen.lock().expect("lock") = Some(req);
            IngressResponse::json(201, br#"{"result":"created"}"#.to_vec())
                .with_header("x-request-id", "req-7")
        }
    }

    /// Plaintext round trip: the RPC reaches the handler as a classified `PUT`, the
    /// reply fields mirror the handler response, and no TLS identity is reported.
    #[tokio::test]
    async fn index_rpc_drives_the_shared_handler_and_maps_the_reply() {
        let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");
        let addr = listener.local_addr().expect("addr");
        let handler = Arc::new(RecordingHandler::default());
        let handler_for_server = Arc::clone(&handler);
        tokio::spawn(async move { serve_grpc(listener, handler_for_server).await });

        let mut client = DocumentServiceClient::connect(format!("http://{addr}"))
            .await
            .expect("connect");
        let reply = client
            .index(IndexRequest {
                index: "orders".to_owned(),
                id: "acme:1".to_owned(),
                document: br#"{"msg":"hi"}"#.to_vec(),
            })
            .await
            .expect("rpc")
            .into_inner();

        assert_eq!(reply.status, 201);
        assert_eq!(reply.request_id, "req-7");
        assert_eq!(reply.body, br#"{"result":"created"}"#);

        let seen = handler.seen.lock().expect("lock").clone().expect("seen");
        assert_eq!(seen.method, HttpMethod::Put);
        assert_eq!(seen.path, "/orders/_doc/acme:1");
        assert_eq!(seen.endpoint, EndpointKind::IngestDoc);
        assert_eq!(seen.logical_index, "orders");
        assert_eq!(seen.doc_id.as_deref(), Some("acme:1"));
        assert_eq!(seen.body, br#"{"msg":"hi"}"#);
        // The plaintext listener must never claim transport security or a client identity.
        assert!(!seen.secure, "plaintext request reported as secure");
        assert!(
            seen.client_cert_subject.is_none(),
            "plaintext request carried a client cert subject"
        );
    }

    /// mTLS round trip: a client certificate signed by the trusted CA is surfaced to the
    /// handler as `client_cert_subject`, and the request is flagged as secure.
    #[tokio::test]
    async fn mtls_index_surfaces_the_verified_client_identity() {
        use rcgen::{BasicConstraints, CertificateParams, ExtendedKeyUsagePurpose, IsCa, KeyPair};
        use tonic::transport::{Certificate, Channel, ClientTlsConfig, Identity};

        // Throwaway CA that signs both the server and the client certificate.
        let ca_key = KeyPair::generate().expect("ca key");
        let mut ca_params = CertificateParams::new(Vec::new()).expect("ca params");
        ca_params.is_ca = IsCa::Ca(BasicConstraints::Unconstrained);
        let ca = ca_params.self_signed(&ca_key).expect("ca");

        let srv_key = KeyPair::generate().expect("srv key");
        let srv = CertificateParams::new(vec!["localhost".to_owned()])
            .expect("srv params")
            .signed_by(&srv_key, &ca, &ca_key)
            .expect("srv cert");

        let cli_key = KeyPair::generate().expect("cli key");
        let mut cli_params = CertificateParams::new(Vec::new()).expect("cli params");
        cli_params.extended_key_usages = vec![ExtendedKeyUsagePurpose::ClientAuth];
        let cli = cli_params
            .signed_by(&cli_key, &ca, &ca_key)
            .expect("cli cert");

        let provider = crate::DefaultCryptoProvider::from_pem_mtls(
            srv.pem().as_bytes(),
            srv_key.serialize_pem().as_bytes(),
            ca.pem().as_bytes(),
        )
        .expect("provider");

        let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");
        let port = listener.local_addr().expect("addr").port();
        let handler = Arc::new(RecordingHandler::default());
        let server_handler = Arc::clone(&handler);
        tokio::spawn(
            async move { serve_grpc_tls(listener, Arc::new(provider), server_handler).await },
        );

        let tls = ClientTlsConfig::new()
            .ca_certificate(Certificate::from_pem(ca.pem()))
            .identity(Identity::from_pem(cli.pem(), cli_key.serialize_pem()))
            .domain_name("localhost");
        let channel = Channel::from_shared(format!("https://localhost:{port}"))
            .expect("uri")
            .tls_config(tls)
            .expect("tls")
            .connect()
            .await
            .expect("connect");
        let mut client = DocumentServiceClient::new(channel);
        client
            .index(IndexRequest {
                index: "orders".to_owned(),
                id: "acme:1".to_owned(),
                document: b"{}".to_vec(),
            })
            .await
            .expect("rpc");

        let seen = handler.seen.lock().expect("lock").clone().expect("seen");
        // The TLS listener must mark the request as secure.
        assert!(seen.secure, "TLS request not reported as secure");
        let subject = seen.client_cert_subject.expect("client cert subject");
        assert!(subject.starts_with("cert:"), "got {subject}");
    }
}