use std::{fmt, sync::Arc};
use tokio::sync::mpsc;
use tokio_stream::{Stream, StreamExt};
use tonic::{Request, Response, Status, Streaming};
use super::ReflectionServiceState;
use crate::pb::v1alpha::server_reflection_request::MessageRequest;
use crate::pb::v1alpha::server_reflection_response::MessageResponse;
pub use crate::pb::v1alpha::server_reflection_server::{ServerReflection, ServerReflectionServer};
use crate::pb::v1alpha::{
ExtensionNumberResponse, FileDescriptorResponse, ListServiceResponse, ServerReflectionRequest,
ServerReflectionResponse, ServiceResponse,
};
#[derive(Debug)]
pub struct ReflectionService {
state: Arc<ReflectionServiceState>,
}
#[tonic::async_trait]
impl ServerReflection for ReflectionService {
type ServerReflectionInfoStream = ServerReflectionInfoStream;
async fn server_reflection_info(
&self,
req: Request<Streaming<ServerReflectionRequest>>,
) -> Result<Response<Self::ServerReflectionInfoStream>, Status> {
let mut req_rx = req.into_inner();
let (resp_tx, resp_rx) = mpsc::channel::<Result<ServerReflectionResponse, Status>>(1);
let state = self.state.clone();
tokio::spawn(async move {
while let Some(req) = req_rx.next().await {
let Ok(req) = req else {
return;
};
let resp_msg = match req.message_request.clone() {
None => Err(Status::invalid_argument("invalid MessageRequest")),
Some(msg) => match msg {
MessageRequest::FileByFilename(s) => state.file_by_filename(&s).map(|fd| {
MessageResponse::FileDescriptorResponse(FileDescriptorResponse {
file_descriptor_proto: vec![fd],
})
}),
MessageRequest::FileContainingSymbol(s) => {
state.symbol_by_name(&s).map(|fd| {
MessageResponse::FileDescriptorResponse(FileDescriptorResponse {
file_descriptor_proto: vec![fd],
})
})
}
MessageRequest::FileContainingExtension(_) => {
Err(Status::not_found("extensions are not supported"))
}
MessageRequest::AllExtensionNumbersOfType(_) => {
Ok(MessageResponse::AllExtensionNumbersResponse(
ExtensionNumberResponse::default(),
))
}
MessageRequest::ListServices(_) => {
Ok(MessageResponse::ListServicesResponse(ListServiceResponse {
service: state
.list_services()
.iter()
.map(|s| ServiceResponse { name: s.clone() })
.collect(),
}))
}
},
};
match resp_msg {
Ok(resp_msg) => {
let resp = ServerReflectionResponse {
valid_host: req.host.clone(),
original_request: Some(req.clone()),
message_response: Some(resp_msg),
};
if resp_tx.send(Ok(resp)).await.is_err() {
return;
}
}
Err(status) => {
let _ = resp_tx.send(Err(status)).await;
return;
}
}
}
});
Ok(Response::new(ServerReflectionInfoStream::new(resp_rx)))
}
}
impl From<ReflectionServiceState> for ReflectionService {
fn from(state: ReflectionServiceState) -> Self {
Self {
state: Arc::new(state),
}
}
}
pub struct ServerReflectionInfoStream {
inner: tokio_stream::wrappers::ReceiverStream<Result<ServerReflectionResponse, Status>>,
}
impl ServerReflectionInfoStream {
fn new(resp_rx: mpsc::Receiver<Result<ServerReflectionResponse, Status>>) -> Self {
let inner = tokio_stream::wrappers::ReceiverStream::new(resp_rx);
Self { inner }
}
}
impl Stream for ServerReflectionInfoStream {
type Item = Result<ServerReflectionResponse, Status>;
fn poll_next(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
std::pin::Pin::new(&mut self.inner).poll_next(cx)
}
fn size_hint(&self) -> (usize, Option<usize>) {
self.inner.size_hint()
}
}
impl fmt::Debug for ServerReflectionInfoStream {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_tuple("ServerReflectionInfoStream").finish()
}
}