use std::{
collections::{HashMap, HashSet},
fmt::{self, Write},
path::{Path, PathBuf},
};
use prost::Message;
use prost_reflect::{DescriptorPool, DynamicMessage, ReflectMessage, Value};
use prost_types::{FileDescriptorProto, FileDescriptorSet};
use crate::{
error::{Error, ErrorKind},
file::{check_shadow, path_to_file_name, File, FileMetadata, FileResolver},
};
#[cfg(test)]
mod tests;
/// Compiles protobuf files into file descriptors.
///
/// Files are located and opened through a [`FileResolver`], added to an internal
/// [`DescriptorPool`], and tracked by name so that files opened explicitly can be told apart
/// from files that were only pulled in as imports.
pub struct Compiler {
// Pool that every successfully checked file is added to.
pool: DescriptorPool,
// Resolver used to map paths to file names and to open files by name.
resolver: Box<dyn FileResolver>,
// Metadata for each file added so far, keyed by file name.
files: HashMap<String, FileMetadata>,
// When false, files marked as imports are left out of generated descriptor sets.
include_imports: bool,
// When false, `source_code_info` is stripped from generated descriptor sets.
include_source_info: bool,
}
impl Compiler {
pub fn new<I, P>(includes: I) -> Result<Self, Error>
where
I: IntoIterator<Item = P>,
P: AsRef<Path>,
{
use crate::file::{ChainFileResolver, GoogleFileResolver, IncludeFileResolver};
let mut resolver = ChainFileResolver::new();
for include in includes {
resolver.add(IncludeFileResolver::new(include.as_ref().to_owned()));
}
resolver.add(GoogleFileResolver::new());
Ok(Compiler::with_file_resolver(resolver))
}
pub fn with_file_resolver<R>(resolver: R) -> Self
where
R: FileResolver + 'static,
{
Compiler {
pool: DescriptorPool::new(),
resolver: Box::new(resolver),
files: HashMap::new(),
include_imports: false,
include_source_info: false,
}
}
/// Sets whether source info is kept in generated file descriptor sets.
///
/// When this is `false` (the initial setting), `source_code_info` is cleared from every file
/// produced by `file_descriptor_set` and `encode_file_descriptor_set`.
///
/// Returns `self` so calls can be chained.
pub fn include_source_info(&mut self, yes: bool) -> &mut Self {
self.include_source_info = yes;
self
}
/// Sets whether files marked as imports are included in generated file descriptor sets.
///
/// When this is `false` (the initial setting), `file_descriptor_set` and
/// `encode_file_descriptor_set` skip every file whose metadata has `is_import` set.
///
/// Returns `self` so calls can be chained.
pub fn include_imports(&mut self, yes: bool) -> &mut Self {
self.include_imports = yes;
self
}
pub fn open_file(&mut self, path: impl AsRef<Path>) -> Result<&mut Self, Error> {
let path = path.as_ref();
let (name, is_resolved) = if let Some(name) = self.resolver.resolve_path(path) {
(name, true)
} else if let Some(name) = path_to_file_name(path) {
(name, false)
} else {
return Err(Error::from_kind(ErrorKind::FileNotIncluded {
path: path.to_owned(),
}));
};
if let Some(file_metadata) = self.files.get_mut(&name) {
if is_resolved {
check_shadow(&name, file_metadata.path(), path)?;
}
file_metadata.is_import = false;
return Ok(self);
}
let file = self.resolver.open_file(&name).map_err(|err| {
if err.is_file_not_found() {
Error::from_kind(ErrorKind::FileNotIncluded {
path: path.to_owned(),
})
} else {
err
}
})?;
if is_resolved {
check_shadow(&name, file.path(), path)?;
}
let mut import_stack = vec![name.clone()];
let mut already_imported = HashSet::new();
for (i, import) in file.descriptor.dependency.iter().enumerate() {
if !already_imported.insert(import) {
return Err(Error::duplicated_import(import.to_owned(), &file, i));
}
self.add_import(import, &mut import_stack)
.map_err(|e| e.into_import_error(&file, i))?;
}
let path = self.check_file(file)?;
self.files.insert(
name.clone(),
FileMetadata {
name,
path,
is_import: false,
},
);
Ok(self)
}
/// Opens each path in `paths` in order, as if by calling `open_file` on every one.
///
/// # Errors
///
/// Stops at, and returns, the first error produced by `open_file`. Files opened before the
/// failure remain added.
pub fn open_files(
    &mut self,
    paths: impl IntoIterator<Item = impl AsRef<Path>>,
) -> Result<&mut Self, Error> {
    paths
        .into_iter()
        .try_for_each(|path| self.open_file(path).map(|_| ()))?;
    Ok(self)
}
/// Builds a `FileDescriptorSet` from the files currently in the descriptor pool.
///
/// Files marked as imports are skipped unless `include_imports` is enabled, and
/// `source_code_info` is cleared unless `include_source_info` is enabled.
///
/// # Panics
///
/// When `include_imports` is disabled, panics if the pool contains a file whose name has no
/// entry in the compiler's file table.
pub fn file_descriptor_set(&self) -> prost_types::FileDescriptorSet {
    let mut file = Vec::new();
    for descriptor in self.pool.files() {
        // The metadata lookup only happens when imports have to be filtered out.
        if !self.include_imports && self.files[descriptor.name()].is_import {
            continue;
        }

        let mut proto = descriptor.file_descriptor_proto().clone();
        if !self.include_source_info {
            proto.source_code_info = None;
        }
        file.push(proto);
    }
    prost_types::FileDescriptorSet { file }
}
/// Encodes the files currently in the descriptor pool as a serialized `FileDescriptorSet`.
///
/// When both `include_imports` and `include_source_info` are enabled, the pool is encoded
/// directly. Otherwise each file is re-decoded as a dynamic message so that imports can be
/// skipped and `source_code_info` can be cleared before the set is encoded.
///
/// # Panics
///
/// Panics if a file encoded by the pool cannot be decoded back as a `FileDescriptorProto`, or
/// (when `include_imports` is disabled) if the pool contains a file whose name has no entry
/// in the compiler's file table.
pub fn encode_file_descriptor_set(&self) -> Vec<u8> {
    // Nothing to filter or strip: the pool's own encoding is used as-is.
    if self.include_imports && self.include_source_info {
        return self.pool.encode_to_vec();
    }

    let file_desc = FileDescriptorProto::default().descriptor();
    let mut files = Vec::new();
    for descriptor in self.pool.files() {
        if !self.include_imports && self.files[descriptor.name()].is_import {
            continue;
        }

        let file_buf = descriptor.encode_to_vec();
        let mut file_msg =
            DynamicMessage::decode(file_desc.clone(), file_buf.as_slice()).unwrap();
        if !self.include_source_info {
            file_msg.clear_field_by_name("source_code_info");
        }
        files.push(Value::Message(file_msg));
    }

    let mut set_msg = FileDescriptorSet::default().transcode_to_dynamic();
    set_msg.set_field_by_name("file", Value::List(files));
    set_msg.encode_to_vec()
}
/// Returns a clone of the descriptor pool holding every file added to this compiler so far.
pub fn descriptor_pool(&self) -> DescriptorPool {
self.pool.clone()
}
/// Iterates over the metadata of every file in the descriptor pool, in the pool's order.
///
/// # Panics
///
/// Advancing the iterator panics if the pool contains a file whose name has no entry in the
/// compiler's file table.
pub fn files(&self) -> impl ExactSizeIterator<Item = &'_ FileMetadata> {
    let metadata_by_name = &self.files;
    self.pool
        .files()
        .map(move |descriptor| &metadata_by_name[descriptor.name()])
}
fn add_import(&mut self, file_name: &str, import_stack: &mut Vec<String>) -> Result<(), Error> {
if import_stack.iter().any(|name| name == file_name) {
let mut cycle = String::new();
for import in import_stack {
write!(&mut cycle, "{} -> ", import).unwrap();
}
write!(&mut cycle, "{}", file_name).unwrap();
return Err(Error::from_kind(ErrorKind::CircularImport {
name: file_name.to_owned(),
cycle,
}));
}
if self.files.contains_key(file_name) {
return Ok(());
}
let file = self.resolver.open_file(file_name)?;
import_stack.push(file_name.to_owned());
let mut already_imported = HashSet::new();
for (i, import) in file.descriptor.dependency.iter().enumerate() {
if !already_imported.insert(import) {
return Err(Error::duplicated_import(import.to_owned(), &file, i));
}
self.add_import(import, import_stack)
.map_err(|e| e.into_import_error(&file, i))?;
}
import_stack.pop();
let path = self.check_file(file)?;
self.files.insert(
file_name.to_owned(),
FileMetadata {
name: file_name.to_owned(),
path,
is_import: true,
},
);
Ok(())
}
/// Adds `file` to the descriptor pool and returns the file's `path`.
///
/// If the file carries an `encoded` descriptor, those bytes are decoded into the pool;
/// otherwise the parsed `descriptor` is added. The encoded bytes are moved into the pool call
/// rather than cloned, since this function owns the `File`.
///
/// # Errors
///
/// Returns the pool's error converted into [`Error`]. If the file has `source` text, it is
/// attached to the pool error via `with_source_code` before the conversion.
fn check_file(
    &mut self,
    File {
        path,
        source,
        descriptor,
        encoded,
    }: File,
) -> Result<Option<PathBuf>, Error> {
    let added = match encoded {
        Some(encoded) => self.pool.decode_file_descriptor_proto(encoded),
        None => self.pool.add_file_descriptor_proto(descriptor),
    };

    added.map_err(|mut err| {
        if let Some(source) = source {
            err = err.with_source_code(&source);
        }
        err
    })?;

    Ok(path)
}
}
/// Debug output shows only the two configuration flags. The pool, resolver and file table
/// are omitted and the struct is rendered as non-exhaustive.
impl fmt::Debug for Compiler {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("Compiler");
        builder.field("include_imports", &self.include_imports);
        builder.field("include_source_info", &self.include_source_info);
        builder.finish_non_exhaustive()
    }
}