//! grpcweb-cli 0.1.0
//!
//! A simple command-line client for gRPC-Web.
use std::path::PathBuf;
use std::sync::Arc;

use bytes::{Buf, BufMut};
use clap::Parser;
use hyper_util::rt::TokioExecutor;
use protobuf::descriptor::FileDescriptorProto;
use protobuf::reflect::{FileDescriptor, MessageDescriptor};
use protobuf::MessageDyn;
use tonic::client::Grpc;
use tonic::{
  codec::{Codec, DecodeBuf, Decoder, EncodeBuf, Encoder},
  Code, Status,
};
use tonic_web::GrpcWebClientLayer;

/// Tonic [`Codec`] that carries `protobuf` dynamic messages (`Box<dyn MessageDyn>`).
///
/// Outgoing messages are serialized as-is, so the encoder needs no schema.
/// Incoming bytes are parsed using the descriptor handed to [`TonicCodec::new`].
#[derive(Debug, Clone)]
struct TonicCodec {
  message_descriptor: Arc<MessageDescriptor>,
}

impl TonicCodec {
  /// Creates a codec whose decoder parses responses as `message_descriptor` messages.
  fn new(message_descriptor: MessageDescriptor) -> Self {
    let message_descriptor = Arc::new(message_descriptor);
    Self { message_descriptor }
  }
}

impl Codec for TonicCodec {
  type Encode = Box<dyn MessageDyn>;
  type Decode = Box<dyn MessageDyn>;

  type Encoder = TonicEncoder;
  type Decoder = TonicDecoder;

  fn encoder(&mut self) -> Self::Encoder {
    // Stateless: serialization only needs the message itself.
    TonicEncoder
  }

  fn decoder(&mut self) -> Self::Decoder {
    // Every decoder shares the same descriptor; cloning the Arc is a refcount bump.
    let message_descriptor = Arc::clone(&self.message_descriptor);
    TonicDecoder { message_descriptor }
  }
}

#[derive(Debug, Clone)]
struct TonicEncoder;

impl Encoder for TonicEncoder {
  type Item = Box<dyn MessageDyn>;
  type Error = Status;

  fn encode(&mut self, item: Self::Item, buf: &mut EncodeBuf<'_>) -> Result<(), Self::Error> {
    let mut writer = buf.writer();
    item
      .write_to_writer_dyn(&mut writer)
      .expect("Message only errors if not enough space");

    Ok(())
  }
}

/// Decoder half of `TonicCodec`: parses each response frame as a dynamic
/// message of the configured output type.
#[derive(Debug, Clone)]
struct TonicDecoder {
  message_descriptor: Arc<MessageDescriptor>,
}

impl Decoder for TonicDecoder {
  type Item = Box<dyn MessageDyn>;
  type Error = Status;

  /// Reads the frame in `buf` into a new message; any parse failure is reported
  /// as a `Code::Internal` status carrying the parser's error text.
  fn decode(&mut self, buf: &mut DecodeBuf<'_>) -> Result<Option<Self::Item>, Self::Error> {
    let descriptor = &self.message_descriptor;
    match descriptor.parse_from_reader(&mut buf.reader()) {
      Ok(parsed) => Ok(Some(parsed)),
      Err(parse_error) => Err(Status::new(Code::Internal, parse_error.to_string())),
    }
  }
}

/// A simple command-line client for gRPC-Web
// Command-line arguments, parsed with clap's derive API. The `///` comments in
// this struct are rendered by clap as `--help` text, so rewording them changes
// user-visible output.
#[derive(Parser, Debug)]
#[command(version, about, long_about = None)]
struct Args {
  /// Path to the Protobuf definition file.
  #[arg(short, long)]
  proto: PathBuf,

  /// Path to the Protobuf include directory. Default would be a parent directory of the proto file.
  // When omitted, `main` falls back to `proto.parent()` (or an empty path if
  // there is no parent).
  #[arg(short, long)]
  include: Option<PathBuf>,

  /// Whether to use `protoc` to generate the Protobuf definition file.
  // When false, `main` selects protobuf_parse's `pure()` parser instead of `protoc()`.
  #[arg(long, default_value_t = false)]
  protoc: bool,

  /// The JSON-encoded request data that would be sent to the server.
  // When absent, the request is a freshly created (default) instance of the
  // method's input message.
  #[arg(short, long)]
  data: Option<String>,

  /// The destination URL (along with the gRPC service route)
  // The URL path both selects the service/method in the proto file and is sent
  // as the gRPC request path; the scheme and authority form the origin.
  #[arg(short, long)]
  url: hyper::Uri,
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
  let args = Args::parse();

  let mut protobuf_parser = protobuf_parse::Parser::new();
  if args.protoc {
    protobuf_parser.protoc();
  } else {
    protobuf_parser.pure();
  }
  let mut parsed = protobuf_parser
    .input(&args.proto)
    .include(
      args.include.unwrap_or(
        args
          .proto
          .parent()
          .map_or(PathBuf::new(), |p| p.to_path_buf()),
      ),
    )
    .parse_and_typecheck()?;
  let file_descriptor_proto: FileDescriptorProto = parsed
    .file_descriptors
    .pop()
    .ok_or(anyhow::anyhow!("No file descriptor found"))?;
  let file_descriptor: FileDescriptor = FileDescriptor::new_dynamic(file_descriptor_proto, &[])?;
  let split = args.url.path().rsplit_once('/');
  let method_name = split.map(|(_, m)| m).unwrap_or_default();
  let service_name = split
    .map(|(s, _)| s)
    .unwrap_or_default()
    .rsplit_once('.')
    .map(|(_, s)| s)
    .unwrap_or_default();
  let service = file_descriptor
    .services()
    .find(|service| service.proto().name() == service_name)
    .ok_or(anyhow::anyhow!("Service not found"))?;
  let method = service
    .methods()
    .find(|method| method.proto().name() == method_name)
    .ok_or(anyhow::anyhow!("Method not found"))?;
  let input_descriptor = file_descriptor
    .message_by_package_relative_name(method.input_type().name_to_package())
    .ok_or(anyhow::anyhow!("Input message not found"))?;
  let output_descriptor = file_descriptor
    .message_by_package_relative_name(method.output_type().name_to_package())
    .ok_or(anyhow::anyhow!("Output message not found"))?;

  let mut input_message = input_descriptor.new_instance();
  if let Some(data) = args.data {
    protobuf_json_mapping::merge_from_str(&mut *input_message, &data)?;
  }
  let codec = TonicCodec::new(output_descriptor);

  let url = args.url;
  let path_and_query = url
    .path_and_query()
    .ok_or(anyhow::anyhow!("Invalid URL"))?
    .to_owned();
  let is_client_streaming = method.proto().has_client_streaming();
  let is_server_streaming = method.proto().has_server_streaming();

  // Have to use Tokio, since many examples use HTTP clients that depend on Tokio...
  let runtime = tokio::runtime::Builder::new_current_thread()
    .enable_all()
    .build()?;
  let message = runtime.block_on(async move {
    let connector = hyper_rustls::HttpsConnectorBuilder::new()
      .with_native_roots()?
      .https_or_http()
      .enable_all_versions()
      .build();
    let client = hyper_util::client::legacy::Client::builder(TokioExecutor::new()).build(connector);

    let svc = tower::ServiceBuilder::new()
      .layer(GrpcWebClientLayer::new())
      .service(client);

    let mut url_parts = url.into_parts();
    url_parts.path_and_query = Some("/".try_into()?);
    let url = url_parts.try_into()?;
    let mut client = Grpc::with_origin(svc, url);

    let response_data = if is_client_streaming {
      let response = client
        .client_streaming(
          tonic::Request::new(futures_util::stream::once(async move { input_message })),
          path_and_query,
          codec,
        )
        .await?;
      let message = response.into_inner();
      protobuf_json_mapping::print_to_string(&*message)?
    } else if is_server_streaming {
      let response = client
        .server_streaming(tonic::Request::new(input_message), path_and_query, codec)
        .await?;
      let mut message = response.into_inner();
      let mut final_data = String::new();
      final_data.push('[');
      let mut first = false;
      while let Some(message) = message.message().await? {
        final_data.push_str(&protobuf_json_mapping::print_to_string(&*message)?);
        if !first {
          final_data.push(',');
        }
        first = false;
      }
      final_data.push(']');
      final_data
    } else {
      let response = client
        .unary(tonic::Request::new(input_message), path_and_query, codec)
        .await?;
      let message = response.into_inner();
      protobuf_json_mapping::print_to_string(&*message)?
    };

    Ok::<String, Box<dyn std::error::Error>>(response_data)
  })?;

  println!("{}", message);

  Ok(())
}