use bytes::{Buf, Bytes};
use prost::{
encoding::{check_wire_type, decode_key, decode_varint, skip_field, DecodeContext, WireType},
DecodeError, Message,
};
use prost_types::FileDescriptorProto;
use crate::{
file::{File, FileResolver},
Error,
};
/// A [`FileResolver`] that serves files from an in-memory set of already-compiled
/// file descriptors, rather than from `.proto` source text.
///
/// Lookup is by the file's `name` field; see the `FileResolver` impl.
#[derive(Debug)]
pub struct DescriptorSetFileResolver {
    /// The files in the set, in the order they were supplied (`new`) or decoded (`decode`).
    set: Vec<FileDescriptor>,
}
/// One entry of a [`DescriptorSetFileResolver`]: the parsed descriptor plus,
/// when available, the exact bytes it was decoded from.
#[derive(Debug, Clone, Default, PartialEq)]
struct FileDescriptor {
    /// The parsed file descriptor.
    file: FileDescriptorProto,
    /// The original protobuf encoding of `file`. `None` when the descriptor was
    /// handed over already parsed (via `DescriptorSetFileResolver::new`).
    encoded: Option<Bytes>,
}
impl DescriptorSetFileResolver {
    /// Creates a resolver from an already-decoded `FileDescriptorSet`.
    ///
    /// Because the files arrive already parsed, their original encoded bytes are
    /// not available, and files opened from this resolver carry `encoded: None`.
    /// Use [`DescriptorSetFileResolver::decode`] to retain the raw bytes.
    pub fn new(set: prost_types::FileDescriptorSet) -> Self {
        let set = set
            .file
            .into_iter()
            .map(|file| FileDescriptor {
                file,
                encoded: None,
            })
            .collect();
        DescriptorSetFileResolver { set }
    }

    /// Decodes an encoded `FileDescriptorSet` from `buf`, keeping the raw encoded
    /// bytes of every contained file alongside its parsed form.
    ///
    /// Fields other than `file` (field 1) are skipped.
    ///
    /// # Errors
    ///
    /// Returns a [`DecodeError`] if `buf` is not a valid encoding. This includes
    /// the case where a `file` entry declares a length larger than the remaining input.
    pub fn decode<B>(mut buf: B) -> Result<Self, DecodeError>
    where
        B: Buf,
    {
        // Field number of `FileDescriptorSet.file`.
        const FILE_TAG: u32 = 1;

        let mut set = Vec::new();
        while buf.has_remaining() {
            let (key, wire_type) = decode_key(&mut buf)?;
            if key == FILE_TAG {
                check_wire_type(WireType::LengthDelimited, wire_type)?;
                // Bounds-check in u64 *before* narrowing. Casting to usize first
                // could truncate a huge length into a small in-bounds value on
                // 32-bit targets.
                let len = decode_varint(&mut buf)?;
                if len > buf.remaining() as u64 {
                    return Err(buffer_underflow_error());
                }
                // Lossless: `len <= buf.remaining()`, which is a usize.
                let len = len as usize;
                // `take` limits the nested decode to exactly this entry's bytes
                // and advances `buf` past them.
                set.push(FileDescriptor::decode((&mut buf).take(len))?);
            } else {
                skip_field(wire_type, key, &mut buf, DecodeContext::default())?;
            }
        }
        Ok(DescriptorSetFileResolver { set })
    }
}
impl FileResolver for DescriptorSetFileResolver {
fn open_file(&self, name: &str) -> Result<File, Error> {
for file in &self.set {
if file.file.name() == name {
return Ok(File {
path: None,
source: None,
descriptor: file.file.clone(),
encoded: file.encoded.clone(),
});
}
}
Err(Error::file_not_found(name))
}
}
impl FileDescriptor {
fn decode(mut buf: impl Buf) -> Result<Self, DecodeError> {
let encoded = buf.copy_to_bytes(buf.remaining());
let file = FileDescriptorProto::decode(&mut encoded.as_ref())?;
Ok(FileDescriptor {
file,
encoded: Some(encoded),
})
}
}
fn buffer_underflow_error() -> DecodeError {
prost::encoding::skip_field(
WireType::ThirtyTwoBit,
1,
&mut [].as_slice(),
DecodeContext::default(),
)
.unwrap_err()
}