// prost-build 0.14.3
//
// Generate Prost annotated Rust types from Protocol Buffers files.
// Documentation
use std::borrow::Cow;

use prost_types::{
    field_descriptor_proto::{Label, Type},
    FieldDescriptorProto,
};

use crate::extern_paths::ExternPaths;
use crate::message_graph::MessageGraph;
use crate::{BytesType, Config, MapType, ServiceGenerator};

/// The context providing all the global information needed to generate code.
/// It also provides a more disciplined access to Config
/// and its mutable instance of ServiceGenerator.
///
/// A `Context` is built once in the generation process and is reused by
/// `CodeGenerator` instances created to generate code for each input file.
pub struct Context<'a> {
    /// Exclusive borrow of the generator configuration. Most accessors only
    /// read from it; the mutable borrow is needed to hand out
    /// `&mut dyn ServiceGenerator`.
    config: &'a mut Config,
    /// Graph of known messages, queried for nesting relationships and for
    /// message descriptors by fully-qualified name.
    message_graph: MessageGraph,
    /// Lookup used to resolve Protobuf identifiers to externally provided paths.
    extern_paths: ExternPaths,
    /// `#[prost(prost_path = "...")]` attribute text, formatted once at
    /// construction when `Config::prost_path` is set; `None` otherwise.
    prost_path_attribute: Option<String>,
}

impl<'a> Context<'a> {
    pub fn new(
        config: &'a mut Config,
        message_graph: MessageGraph,
        extern_paths: ExternPaths,
    ) -> Self {
        let prost_path_attribute = config
            .prost_path
            .as_deref()
            .map(|prost_path| format!(r#"#[prost(prost_path = "{prost_path}")]"#));

        Self {
            config,
            message_graph,
            extern_paths,
            prost_path_attribute,
        }
    }

    /// Gives read-only access to the underlying configuration.
    pub fn config(&self) -> &Config {
        // Explicit shared reborrow of the stored `&mut Config`.
        &*self.config
    }

    /// Gives mutable access to the configured service generator, or `None`
    /// when the configuration does not hold one.
    pub fn service_generator_mut(&mut self) -> Option<&mut (dyn ServiceGenerator + 'static)> {
        let generator = &mut self.config.service_generator;
        generator.as_deref_mut()
    }

    /// Returns the `prost` path reported by `Config::prost_path_or_default`.
    pub fn prost_path(&self) -> &str {
        self.config().prost_path_or_default()
    }

    /// Returns the attribute text prepared at construction time, or `None`
    /// when none was prepared.
    pub fn prost_path_attribute(&self) -> Option<&str> {
        let attribute = &self.prost_path_attribute;
        attribute.as_deref()
    }

    /// Looks up `pb_ident` in the extern path table and returns the result of
    /// that lookup unchanged.
    pub fn resolve_extern_ident(&self, pb_ident: &str) -> Option<String> {
        let extern_paths = &self.extern_paths;
        extern_paths.resolve_ident(pb_ident)
    }

    /// Iterates over the extra attributes that the configuration associates
    /// with the type named `fq_type_name`, each yielded as a string slice.
    pub fn type_attributes(&self, fq_type_name: &str) -> impl Iterator<Item = &str> {
        let configured = &self.config.type_attributes;
        configured
            .get(fq_type_name)
            .map(|attribute| attribute.as_str())
    }

    /// Iterates over the extra attributes that the configuration associates
    /// with the message named `fq_message_name`, each yielded as a string slice.
    pub fn message_attributes(&self, fq_message_name: &str) -> impl Iterator<Item = &str> {
        let configured = &self.config.message_attributes;
        configured
            .get(fq_message_name)
            .map(|attribute| attribute.as_str())
    }

    /// Iterates over the extra attributes that the configuration associates
    /// with the enum named `fq_enum_name`, each yielded as a string slice.
    pub fn enum_attributes(&self, fq_enum_name: &str) -> impl Iterator<Item = &str> {
        let configured = &self.config.enum_attributes;
        configured
            .get(fq_enum_name)
            .map(|attribute| attribute.as_str())
    }

    /// Iterates over the extra attributes that the configuration associates
    /// with field `field_name` of the message `fq_message_name`, each yielded
    /// as a string slice.
    pub fn field_attributes(
        &self,
        fq_message_name: &str,
        field_name: &str,
    ) -> impl Iterator<Item = &str> {
        let configured = &self.config.field_attributes;
        configured
            .get_field(fq_message_name, field_name)
            .map(|attribute| attribute.as_str())
    }

    /// Returns the bytes type configured for the named message field.
    pub(crate) fn bytes_type(&self, fq_message_name: &str, field_name: &str) -> BytesType {
        self.config
            .bytes_type
            .get_first_field(fq_message_name, field_name)
            .copied()
            .unwrap_or_default()
    }

    /// Returns the map type configured for the named message field.
    pub(crate) fn map_type(&self, fq_message_name: &str, field_name: &str) -> MapType {
        self.config
            .map_type
            .get_first_field(fq_message_name, field_name)
            .copied()
            .unwrap_or_default()
    }

    /// Decides whether a plain (non-oneof) message field must be generated as
    /// `Box<_>`.
    ///
    /// Boxing is required either because the field was matched by the
    /// `boxed` configuration, or because a non-repeated field's message type
    /// is nested with the containing message, which would otherwise produce
    /// an infinitely sized type.
    pub fn should_box_message_field(
        &self,
        fq_message_name: &str,
        field: &FieldDescriptorProto,
    ) -> bool {
        // A regular field has no enclosing oneof.
        let no_oneof: Option<&str> = None;
        self.should_box_impl(fq_message_name, no_oneof, field)
    }

    /// Decides whether a field that belongs to the oneof `oneof_name` must be
    /// generated as `Box<_>`.
    ///
    /// Boxing is required either because the field was matched by the
    /// `boxed` configuration (looked up under the
    /// `<fq_message_name>.<oneof_name>` path), or because a non-repeated
    /// field's message type is nested with the containing message, which
    /// would otherwise produce an infinitely sized type.
    pub fn should_box_oneof_field(
        &self,
        fq_message_name: &str,
        oneof_name: &str,
        field: &FieldDescriptorProto,
    ) -> bool {
        let enclosing_oneof = Some(oneof_name);
        self.should_box_impl(fq_message_name, enclosing_oneof, field)
    }

    /// Shared boxing decision for regular and oneof fields.
    ///
    /// * Repeated fields are never boxed.
    /// * Message or group fields whose type is nested with
    ///   `fq_message_name` according to the message graph are always boxed.
    /// * Otherwise the field is boxed only if the `boxed` configuration has
    ///   an entry for it; for oneof members the lookup path is
    ///   `<fq_message_name>.<oneof>`.
    fn should_box_impl(
        &self,
        fq_message_name: &str,
        oneof: Option<&str>,
        field: &FieldDescriptorProto,
    ) -> bool {
        // A repeated field lives in a `Vec`, so it is heap allocated already.
        if field.label() == Label::Repeated {
            return false;
        }

        let is_message_like = matches!(field.r#type(), Type::Message | Type::Group);
        if is_message_like
            && self
                .message_graph
                .is_nested(field.type_name(), fq_message_name)
        {
            return true;
        }

        // Only allocate a new path string when the field sits inside a oneof.
        let config_path: Cow<'_, str> = oneof.map_or_else(
            || Cow::Borrowed(fq_message_name),
            |oneof_name| Cow::Owned(format!("{fq_message_name}.{oneof_name}")),
        );

        self.config
            .boxed
            .get_first_field(&config_path, field.name())
            .is_some()
    }

    /// Reports whether `Copy` can be derived for the named message, i.e.
    /// whether every one of its fields passes [`Self::can_field_derive_copy`].
    ///
    /// # Panics
    ///
    /// Panics if `fq_message_name` does not start with `.`, or if the message
    /// graph has no entry for it.
    pub fn can_message_derive_copy(&self, fq_message_name: &str) -> bool {
        assert_eq!(".", &fq_message_name[..1]);
        let message = self.message_graph.get_message(fq_message_name).unwrap();
        message
            .field
            .iter()
            .all(|candidate| self.can_field_derive_copy(fq_message_name, candidate))
    }

    /// Reports whether the type generated for `field` of message
    /// `fq_message_name` permits deriving `Copy`.
    ///
    /// Repeated fields never qualify. A message-typed field qualifies only if
    /// it is neither nested with the containing message nor matched by the
    /// `boxed` configuration, and the referenced message can itself derive
    /// `Copy`. Of the remaining types, the numeric, `bool` and enum types
    /// qualify; every other type does not.
    ///
    /// # Panics
    ///
    /// Panics if `fq_message_name` does not start with `.`.
    pub fn can_field_derive_copy(
        &self,
        fq_message_name: &str,
        field: &FieldDescriptorProto,
    ) -> bool {
        assert_eq!(".", &fq_message_name[..1]);

        // A repeated field can never be `Copy`.
        if field.label() == Label::Repeated {
            return false;
        }

        match field.r#type() {
            // Nested and boxed messages cannot be `Copy`; anything else
            // depends on the referenced message. The checks run in this
            // order and stop at the first one that fails.
            Type::Message => {
                !self
                    .message_graph
                    .is_nested(field.type_name(), fq_message_name)
                    && self
                        .config
                        .boxed
                        .get_first_field(fq_message_name, field.name())
                        .is_none()
                    && self.can_message_derive_copy(field.type_name())
            }
            Type::Float
            | Type::Double
            | Type::Int32
            | Type::Int64
            | Type::Uint32
            | Type::Uint64
            | Type::Sint32
            | Type::Sint64
            | Type::Fixed32
            | Type::Fixed64
            | Type::Sfixed32
            | Type::Sfixed64
            | Type::Bool
            | Type::Enum => true,
            // Every remaining type does not qualify.
            _ => false,
        }
    }

    /// Reports whether `Eq` can be derived for the named message, i.e.
    /// whether every one of its fields passes [`Self::can_field_derive_eq`].
    ///
    /// # Panics
    ///
    /// Panics if `fq_message_name` does not start with `.`, or if the message
    /// graph has no entry for it.
    pub fn can_message_derive_eq(&self, fq_message_name: &str) -> bool {
        assert_eq!(".", &fq_message_name[..1]);

        self.message_graph
            .get_message(fq_message_name)
            .unwrap()
            .field
            .iter()
            .all(|candidate| self.can_field_derive_eq(fq_message_name, candidate))
    }

    /// Reports whether the type generated for `field` of message
    /// `fq_message_name` permits deriving `Eq`.
    ///
    /// A message-typed field qualifies only if it is not repeated, is not
    /// nested with the containing message, and the referenced message can
    /// itself derive `Eq`. Of the remaining types, the integer, `bool`, enum,
    /// string and bytes types qualify; every other type does not.
    ///
    /// # Panics
    ///
    /// Panics if `fq_message_name` does not start with `.`.
    pub fn can_field_derive_eq(&self, fq_message_name: &str, field: &FieldDescriptorProto) -> bool {
        assert_eq!(".", &fq_message_name[..1]);

        match field.r#type() {
            // The checks run in this order and stop at the first failure.
            Type::Message => {
                field.label() != Label::Repeated
                    && !self
                        .message_graph
                        .is_nested(field.type_name(), fq_message_name)
                    && self.can_message_derive_eq(field.type_name())
            }
            Type::Int32
            | Type::Int64
            | Type::Uint32
            | Type::Uint64
            | Type::Sint32
            | Type::Sint64
            | Type::Fixed32
            | Type::Fixed64
            | Type::Sfixed32
            | Type::Sfixed64
            | Type::Bool
            | Type::Enum
            | Type::String
            | Type::Bytes => true,
            // Every remaining type does not qualify.
            _ => false,
        }
    }

    /// Reports whether the `disable_comments` configuration has an entry for
    /// the given item.
    ///
    /// With `Some(field_name)` the lookup targets that field of
    /// `fq_message_name`; with `None` it targets `fq_message_name` itself.
    pub fn should_disable_comments(&self, fq_message_name: &str, field_name: Option<&str>) -> bool {
        let disabled = &self.config.disable_comments;
        match field_name {
            Some(field_name) => disabled
                .get_first_field(fq_message_name, field_name)
                .is_some(),
            None => disabled.get(fq_message_name).next().is_some(),
        }
    }

    /// Reports whether the `skip_debug` configuration has an entry for the
    /// named message, in which case no `Debug` implementation is generated
    /// for it.
    ///
    /// # Panics
    ///
    /// Panics if `fq_message_name` is empty or does not start with `.`.
    pub fn should_skip_debug(&self, fq_message_name: &str) -> bool {
        assert_eq!(b'.', fq_message_name.as_bytes()[0]);
        let mut configured = self.config.skip_debug.get(fq_message_name);
        configured.next().is_some()
    }

    /// Looks up the type name domain configured for the named message.
    ///
    /// Returns the empty string when the lookup finds nothing.
    pub fn type_name_domain(&self, fq_message_name: &str) -> &str {
        let configured = self.config.type_name_domains.get_first(fq_message_name);
        match configured {
            Some(domain) => domain.as_str(),
            None => "",
        }
    }
}