frpc-lib 0.1.0

A library for Fluid RPC that dynamically invokes gRPC services.
Documentation
use std::collections::{HashMap, HashSet};

use anyhow::{bail, Result};
use prost::Message;
use prost_reflect::DescriptorPool;
use prost_types::FileDescriptorProto;
use tokio_stream::StreamExt;
use tonic::{
    transport::{Channel, Endpoint},
    Request,
};

use tonic_reflection::pb::v1::{
    server_reflection_client::ServerReflectionClient, server_reflection_request::MessageRequest,
    server_reflection_response::MessageResponse, FileDescriptorResponse, ServerReflectionRequest,
};

pub(crate) async fn load_from_server_reflection(server_url: String) -> Result<DescriptorPool> {
    let mut pool = DescriptorPool::new();

    let connection = get_client(server_url).await;

    match connection {
        Ok(client) => match populate_descriptor_pool(&mut pool, client).await {
            Ok(_) => Ok(pool),
            Err(e) => bail!("Error getting Server Reflections: {}", e),
        },
        Err(e) => bail!("Error creating gRPC connection: {}", e),
    }
}

/// Reports whether the server at `server_url` implements the gRPC reflection service.
///
/// A single `ListServices` request is sent over a fresh connection. `Ok(true)` means the server
/// answered with a service list. `Ok(false)` means one of two things: the call was rejected with
/// `Unimplemented`, or the first response carried something other than a service list.
///
/// # Errors
/// Returns an error in any of these cases:
/// - the connection cannot be established;
/// - the call fails with a status other than `Unimplemented`;
/// - the server closes the response stream without sending anything.
pub(crate) async fn check_implemented(server_url: String) -> Result<bool> {
    let mut client = get_client(server_url).await?;

    let list_services_request = ServerReflectionRequest {
        host: String::new(),
        message_request: Some(MessageRequest::ListServices(String::new())),
    };

    let mut responses = match client
        .server_reflection_info(Request::new(tokio_stream::once(list_services_request)))
        .await
    {
        Ok(resp) => resp.into_inner(),
        Err(status) if status.code() == tonic::Code::Unimplemented => return Ok(false),
        Err(status) => bail!(status),
    };

    match responses.next().await {
        Some(Ok(response)) => Ok(matches!(
            response.message_response,
            Some(MessageResponse::ListServicesResponse(_))
        )),
        // Treat an `Unimplemented` status delivered on the stream the same as one returned by
        // the call itself, instead of panicking on it.
        Some(Err(status)) if status.code() == tonic::Code::Unimplemented => Ok(false),
        Some(Err(status)) => bail!(status),
        None => bail!("Server closed the reflection stream without answering ListServices"),
    }
}

/// Connects to `server_url` and wraps the resulting channel in a reflection client.
///
/// # Errors
/// Returns an error if `server_url` is not a valid endpoint URI or the connection attempt fails.
/// The error is handed back to the caller; nothing is printed.
async fn get_client(server_url: String) -> Result<ServerReflectionClient<Channel>> {
    let channel = Endpoint::new(server_url)?.connect().await?;
    Ok(ServerReflectionClient::new(channel))
}

async fn populate_descriptor_pool(
    pool: &mut DescriptorPool,
    mut client: ServerReflectionClient<Channel>,
) -> Result<()> {
    let list_services_request = ServerReflectionRequest {
        host: String::new(),
        message_request: Some(MessageRequest::ListServices(String::new())),
    };

    let services_response = match client
        .server_reflection_info(Request::new(tokio_stream::once(list_services_request)))
        .await
    {
        Ok(resp) => resp,
        Err(e) => bail!(e),
    };

    if let Some(MessageResponse::ListServicesResponse(services)) = services_response
        .into_inner()
        .next()
        .await
        .expect("Response did not exist")
        .expect("Message did not exist")
        .message_response
    {
        let service_file_paths = services.service.into_iter().map(|s| s.name).collect();

        let all_file_bytes = reflect_decode_files_bytes(&mut client, service_file_paths)
            .await
            .expect("Failed to get file bytes");

        let mut dependency_paths: HashSet<String> = HashSet::new();

        let mut decoded_depdendencies: HashMap<String, FileDescriptorProto> = HashMap::new();
        let mut decoded_files: HashMap<String, FileDescriptorProto> = HashMap::new();

        for bytes in all_file_bytes {
            let slice = bytes.as_slice();
            let file_desc = FileDescriptorProto::decode(slice);

            match file_desc {
                Ok(decoded) => {
                    let key = decoded.name();

                    for dep in &decoded.dependency {
                        dependency_paths.insert(dep.to_owned());
                    }

                    decoded_files.insert(key.to_string(), decoded);
                }
                _ => {}
            };
        }

        let dependency_file_bytes =
            reflect_decode_files_bytes(&mut client, dependency_paths.into_iter().collect())
                .await
                .expect("Failed to get dependency file bytes");

        for bytes in dependency_file_bytes {
            let slice = bytes.as_slice();
            let file_desc = FileDescriptorProto::decode(slice);

            match file_desc {
                Ok(decoded) => {
                    let key = decoded.name();

                    decoded_depdendencies.insert(key.to_string(), decoded);
                }
                _ => {}
            }
        }

        match pool.add_file_descriptor_protos(
            decoded_depdendencies
                .values()
                .cloned()
                .collect::<Vec<FileDescriptorProto>>(),
        ) {
            Ok(_) => {}
            Err(e) => bail!("Failed to add dependency to DescriptorPool: {}", e),
        };

        match pool.add_file_descriptor_protos(
            decoded_files
                .values()
                .cloned()
                .collect::<Vec<FileDescriptorProto>>(),
        ) {
            Ok(_) => {}
            Err(e) => bail!("Failed to add file to DescriptorPool: {}", e),
        };
    }

    Ok(())
}

async fn reflect_decode_files_bytes(
    client: &mut ServerReflectionClient<Channel>,
    file_paths: Vec<String>,
) -> Result<Vec<Vec<u8>>> {
    let service_requests: Vec<ServerReflectionRequest> = file_paths
        .into_iter()
        .map(|path| ServerReflectionRequest {
            host: String::new(),
            message_request: Some(MessageRequest::FileContainingSymbol(path)),
        })
        .collect();

    let file_responses = client
        .server_reflection_info(Request::new(tokio_stream::iter(service_requests)))
        .await?
        .into_inner()
        .collect::<Vec<_>>()
        .await
        .into_iter()
        .filter_map(Result::ok)
        .filter_map(|resp| resp.message_response)
        .filter_map(|msg| match msg {
            MessageResponse::FileDescriptorResponse(files) => Some(files),
            _ => None,
        })
        .collect::<Vec<FileDescriptorResponse>>();

    Ok(file_responses
        .into_iter()
        .flat_map(|f| f.file_descriptor_proto)
        .collect::<Vec<Vec<u8>>>())
}