ferricel-core 0.2.1

Core compiler and runtime library for ferricel (CEL → Wasm)
//! Protocol Buffer schema support for proper wrapper type semantics.
//!
//! This module handles parsing FileDescriptorSet (binary proto descriptors) to extract
//! type information needed for implementing CEL wrapper type semantics:
//!
//! - Wrapper-to-primitive equality: `BoolValue{value: true} == true`
//! - Empty wrapper zero values: `BoolValue{} == false`
//! - Unset wrapper field access: `TestAllTypes{}.single_bool_wrapper == null`
//! - Any unpacking for equality: `Any{...} == Any{...}` compares decoded fields

use std::collections::HashMap;

use anyhow::{Context, Result};
use prost::Message;
use prost_types::{DescriptorProto, FileDescriptorSet};

/// The kind of a proto field, used to guide schema-aware wire comparison.
///
/// The classification is derived from the field's declared proto type when the
/// descriptor set is parsed. Equality is total, so the type implements `Eq` as
/// well as `PartialEq`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FieldKind {
    /// Primitive field (varint, fixed32, fixed64) — compare raw wire bytes.
    ///
    /// This is also the fallback for every declared type that is neither `bytes`
    /// nor an embedded message/group (for example `string` and enum fields).
    Primitive,
    /// Bytes field (wire type 2, but not embedded message) — compare raw bytes
    Bytes,
    /// Embedded message field — recurse with schema for the given fully-qualified type name.
    ///
    /// The name is stored without a leading dot and is empty when the descriptor
    /// did not carry a type name.
    Message(String),
}

/// The 9 official Protocol Buffer wrapper types from google/protobuf/wrappers.proto
///
/// Names are fully qualified and written without a leading dot. Membership in this
/// list is what `ProtoSchema::is_wrapper_type` tests, using an exact string match.
const WRAPPER_TYPES: &[&str] = &[
    "google.protobuf.BoolValue",
    "google.protobuf.BytesValue",
    "google.protobuf.DoubleValue",
    "google.protobuf.FloatValue",
    "google.protobuf.Int32Value",
    "google.protobuf.Int64Value",
    "google.protobuf.StringValue",
    "google.protobuf.UInt32Value",
    "google.protobuf.UInt64Value",
];

/// Schema extracted from Protocol Buffer FileDescriptorSet.
///
/// Contains message type information needed for implementing proper wrapper type semantics.
/// Instances are built with `ProtoSchema::from_descriptor_set` and queried by
/// fully-qualified message type name.
#[derive(Debug, Clone)]
pub struct ProtoSchema {
    /// Maps fully-qualified message type names to their schemas.
    ///
    /// Keys are the package and any enclosing message names joined with `.`, with no
    /// leading dot. Nested message types get their own entries.
    messages: HashMap<String, MessageSchema>,
}

/// Schema information for a single Protocol Buffer message type.
#[derive(Debug, Clone)]
struct MessageSchema {
    /// Maps field names to their type information.
    /// Each key is the same string as the `name` stored in the corresponding `FieldSchema`.
    fields: HashMap<String, FieldSchema>,
    /// True if this message is a map entry (i.e., generated for a map<K,V> field).
    /// Taken from the message's `map_entry` option; false when the option is absent.
    is_map_entry: bool,
}

/// Schema information for a message field.
#[derive(Debug, Clone)]
struct FieldSchema {
    /// Field name (empty if the descriptor did not provide one)
    name: String,
    /// Fully-qualified type name for message/enum fields (e.g., "google.protobuf.BoolValue").
    /// Stored with any leading dot removed; `None` when the descriptor has no type name.
    type_name: Option<String>,
    /// Proto field number (1-based, as declared in the .proto file).
    /// Zero when the descriptor did not provide a number.
    number: u32,
    /// Field kind for wire comparison
    kind: FieldKind,
    /// True if this field is declared as `repeated` (includes map fields)
    repeated: bool,
}

impl ProtoSchema {
    /// Build a [`ProtoSchema`] from a serialized `FileDescriptorSet`.
    ///
    /// `bytes` should be the file written by `protoc --descriptor_set_out=...`.
    /// Every message of every file in the set, nested messages included, is
    /// registered under its fully-qualified name.
    ///
    /// # Errors
    ///
    /// Returns an error when `bytes` cannot be decoded as a `FileDescriptorSet`.
    ///
    /// # Example
    ///
    /// ```no_run
    /// use ferricel_core::schema::ProtoSchema;
    ///
    /// let descriptor_bytes = std::fs::read("types.pb").unwrap();
    /// let schema = ProtoSchema::from_descriptor_set(&descriptor_bytes).unwrap();
    /// ```
    pub fn from_descriptor_set(bytes: &[u8]) -> Result<Self> {
        let descriptor_set = FileDescriptorSet::decode(bytes)
            .context("Failed to decode FileDescriptorSet. Make sure the file was generated by protoc with --descriptor_set_out")?;

        let mut messages = HashMap::new();
        for file in &descriptor_set.file {
            // A file without a `package` statement contributes unprefixed names.
            let package = file.package.as_deref().unwrap_or_default();

            // Top-level messages have no enclosing message, hence the empty parent path.
            Self::process_messages(&mut messages, package, "", &file.message_type);
        }

        Ok(Self { messages })
    }

    /// Recursively process message types and nested types.
    fn process_messages(
        messages: &mut HashMap<String, MessageSchema>,
        package: &str,
        parent_path: &str,
        message_types: &[DescriptorProto],
    ) {
        for message in message_types {
            let message_name = message.name.as_deref().unwrap_or("");

            // Build fully-qualified name
            let full_name = if parent_path.is_empty() {
                if package.is_empty() {
                    message_name.to_string()
                } else {
                    format!("{}.{}", package, message_name)
                }
            } else {
                format!("{}.{}", parent_path, message_name)
            };

            // Check if this message is a map entry
            let is_map_entry = message
                .options
                .as_ref()
                .and_then(|o| o.map_entry)
                .unwrap_or(false);

            // Process fields
            let mut fields = HashMap::new();
            for field in &message.field {
                let field_name = field.name.as_deref().unwrap_or("").to_string();
                let type_name = field.type_name.as_ref().map(|tn| {
                    // Remove leading dot if present
                    tn.strip_prefix('.').unwrap_or(tn).to_string()
                });
                let number = field.number.unwrap_or(0) as u32;

                // Determine if repeated
                use prost_types::field_descriptor_proto::Label as FLabel;
                let repeated = field.label() == FLabel::Repeated;

                // Determine field kind from the proto field type
                // prost_types::field_descriptor_proto::Type values:
                //   TYPE_BYTES = 12, TYPE_STRING = 9 → both wire type 2 but Bytes/Primitive
                //   TYPE_MESSAGE = 11, TYPE_GROUP = 10 → wire type 2, embedded message
                //   All others → primitive (varint or fixed)
                use prost_types::field_descriptor_proto::Type as FType;
                let kind = match field.r#type() {
                    FType::Message | FType::Group => {
                        let msg_type = type_name.clone().unwrap_or_default();
                        FieldKind::Message(msg_type)
                    }
                    FType::Bytes => FieldKind::Bytes,
                    _ => FieldKind::Primitive,
                };

                fields.insert(
                    field_name.clone(),
                    FieldSchema {
                        name: field_name,
                        type_name,
                        number,
                        kind,
                        repeated,
                    },
                );
            }

            messages.insert(
                full_name.clone(),
                MessageSchema {
                    fields,
                    is_map_entry,
                },
            );

            // Process nested types
            Self::process_messages(messages, package, &full_name, &message.nested_type);
        }
    }

    /// Report whether `type_name` names one of the official wrapper messages.
    ///
    /// The check is an exact string comparison against the fully-qualified names in
    /// `WRAPPER_TYPES`, which are written without a leading dot.
    ///
    /// # Example
    ///
    /// ```
    /// use ferricel_core::schema::ProtoSchema;
    ///
    /// assert!(ProtoSchema::is_wrapper_type("google.protobuf.BoolValue"));
    /// assert!(ProtoSchema::is_wrapper_type("google.protobuf.Int64Value"));
    /// assert!(!ProtoSchema::is_wrapper_type("google.protobuf.Timestamp"));
    /// assert!(!ProtoSchema::is_wrapper_type("my.custom.Message"));
    /// ```
    pub fn is_wrapper_type(type_name: &str) -> bool {
        WRAPPER_TYPES.iter().any(|&wrapper| wrapper == type_name)
    }

    /// Get the names of the fields of `message_type` whose declared type is a wrapper type.
    ///
    /// The names are returned sorted, so the result is the same on every call and
    /// every run rather than following the hash map's arbitrary iteration order.
    ///
    /// Returns an empty vector if the message type is not in the schema.
    ///
    /// # Example
    ///
    /// ```no_run
    /// use ferricel_core::schema::ProtoSchema;
    ///
    /// let bytes: Vec<u8> = vec![];
    /// let schema = ProtoSchema::from_descriptor_set(&bytes).unwrap();
    /// let wrapper_fields = schema.get_wrapper_fields("cel.expr.conformance.proto3.TestAllTypes");
    /// // For a descriptor set that defines the message, this lists names such as
    /// // "single_bool_wrapper" and "single_int64_wrapper".
    /// ```
    pub fn get_wrapper_fields(&self, message_type: &str) -> Vec<String> {
        let message = match self.messages.get(message_type) {
            Some(message) => message,
            None => return Vec::new(),
        };

        // NOTE(review): only the field's type name is inspected, so a `repeated`
        // field whose element type is a wrapper is included too — confirm callers
        // expect that.
        let mut names: Vec<String> = message
            .fields
            .values()
            .filter(|field| {
                field
                    .type_name
                    .as_deref()
                    .map_or(false, Self::is_wrapper_type)
            })
            .map(|field| field.name.clone())
            .collect();

        // Field names are unique map keys, so an unstable sort is sufficient.
        names.sort_unstable();
        names
    }

    /// Check if a message type exists in the schema.
    ///
    /// `message_type` is matched exactly against the stored fully-qualified names,
    /// which carry no leading dot.
    pub fn has_message_type(&self, message_type: &str) -> bool {
        self.messages.contains_key(message_type)
    }

    /// Describe the collection default of every `repeated` field of `message_type`.
    ///
    /// Repeated fields have a non-trivial proto default that must be produced when the
    /// field is absent from a struct literal: `{}` for map fields and `[]` for list
    /// fields. Non-repeated message fields and scalars are deliberately left out —
    /// their absence is either handled by the wrapper-field logic or is a genuine
    /// runtime error.
    ///
    /// A repeated field counts as a map when its element type is a message that the
    /// schema knows as a map entry; every other repeated field is a list.
    ///
    /// Returns a map of `field_name → "map" | "list"`.
    /// Returns an empty map if the message type is not in the schema.
    pub fn get_field_default_kinds(&self, message_type: &str) -> HashMap<String, String> {
        let message = match self.messages.get(message_type) {
            Some(message) => message,
            None => return HashMap::new(),
        };

        let mut defaults = HashMap::new();
        for field in message.fields.values().filter(|f| f.repeated) {
            // A map field is a repeated field whose element is a generated entry message.
            let is_map = match &field.kind {
                FieldKind::Message(entry_type) => self
                    .messages
                    .get(entry_type.as_str())
                    .map_or(false, |entry| entry.is_map_entry),
                FieldKind::Primitive | FieldKind::Bytes => false,
            };
            let default_kind = if is_map { "map" } else { "list" };
            defaults.insert(field.name.clone(), default_kind.to_string());
        }
        defaults
    }

    /// Produce the per-field comparison schema used when unpacking `Any` values.
    ///
    /// The result maps each field number of `message_type` to an encoded kind string:
    /// - `"primitive"` — compare raw wire bytes (varint / fixed)
    /// - `"bytes"` — compare raw length-delimited payload bytes
    /// - `"message:<fqn>"` — embedded message, recurse with the named type's schema
    ///
    /// Returns an empty map if the message type is not in the schema.
    pub fn get_any_field_schema(&self, message_type: &str) -> HashMap<u32, String> {
        let mut schema = HashMap::new();

        if let Some(message) = self.messages.get(message_type) {
            for field in message.fields.values() {
                let encoded = match &field.kind {
                    FieldKind::Message(fqn) => format!("message:{}", fqn),
                    FieldKind::Bytes => String::from("bytes"),
                    FieldKind::Primitive => String::from("primitive"),
                };
                schema.insert(field.number, encoded);
            }
        }

        schema
    }

    /// Get all message type names in the schema.
    ///
    /// The names are fully qualified and returned in sorted order, so the result
    /// does not depend on the hash map's arbitrary iteration order.
    pub fn message_types(&self) -> Vec<&str> {
        let mut names: Vec<&str> = self.messages.keys().map(String::as_str).collect();
        // Keys are unique, so an unstable sort is sufficient.
        names.sort_unstable();
        names
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// All nine names from wrappers.proto are accepted; other well-known types and
    /// custom messages are rejected.
    #[test]
    fn test_is_wrapper_type() {
        let wrappers = [
            "google.protobuf.BoolValue",
            "google.protobuf.BytesValue",
            "google.protobuf.DoubleValue",
            "google.protobuf.FloatValue",
            "google.protobuf.Int32Value",
            "google.protobuf.Int64Value",
            "google.protobuf.StringValue",
            "google.protobuf.UInt32Value",
            "google.protobuf.UInt64Value",
        ];
        for name in wrappers.iter().copied() {
            assert!(
                ProtoSchema::is_wrapper_type(name),
                "{} should be a wrapper type",
                name
            );
        }

        let non_wrappers = [
            "google.protobuf.Timestamp",
            "google.protobuf.Duration",
            "google.protobuf.Any",
            "google.protobuf.Empty",
            "my.custom.Message",
        ];
        for name in non_wrappers.iter().copied() {
            assert!(
                !ProtoSchema::is_wrapper_type(name),
                "{} should not be a wrapper type",
                name
            );
        }
    }

    /// A descriptor set without files yields a schema without message types.
    #[test]
    fn test_empty_descriptor_set() {
        let encoded = FileDescriptorSet { file: Vec::new() }.encode_to_vec();
        let schema = ProtoSchema::from_descriptor_set(&encoded).unwrap();
        assert!(schema.message_types().is_empty());
    }
}