Skip to main content

mockforge_contracts/contract_drift/
grpc_contract.rs

1//! gRPC contract implementation for protocol-agnostic contract drift detection
2//!
3//! This module provides a `GrpcContract` struct that implements the `ProtocolContract` trait
4//! for gRPC services, enabling drift detection and analysis for protobuf-based APIs.
5
6use crate::contract_drift::protocol_contracts::{
7    ContractError, ContractOperation, ContractRequest, OperationType, ProtocolContract,
8    ValidationError, ValidationResult,
9};
10use mockforge_foundation::contract_diff_types::{
11    ContractDiffResult, Mismatch, MismatchSeverity, MismatchType,
12};
13use mockforge_foundation::protocol::Protocol;
14use prost_reflect::{DescriptorPool, MessageDescriptor, MethodDescriptor, ServiceDescriptor};
15use std::collections::HashMap;
16use std::sync::Arc;
17
18/// gRPC contract implementation
19///
20/// Wraps a protobuf descriptor pool and provides contract drift detection
21/// capabilities for gRPC services and methods.
22pub struct GrpcContract {
23    /// Unique identifier for this contract
24    contract_id: String,
25    /// Contract version
26    version: String,
27    /// Descriptor pool containing the proto definitions
28    #[allow(dead_code)]
29    descriptor_pool: Arc<DescriptorPool>,
30    /// Map of service names to service descriptors
31    services: HashMap<String, ServiceDescriptor>,
32    /// Map of operation IDs (service.method) to method descriptors
33    methods: HashMap<String, MethodDescriptor>,
34    /// Cached contract operations for quick lookup
35    operations_cache: HashMap<String, ContractOperation>,
36    /// Contract metadata
37    metadata: HashMap<String, String>,
38}
39
40impl GrpcContract {
41    /// Create a new gRPC contract from a descriptor pool
42    pub fn new(
43        contract_id: String,
44        version: String,
45        descriptor_pool: Arc<DescriptorPool>,
46    ) -> Result<Self, ContractError> {
47        let mut services = HashMap::new();
48        let mut methods = HashMap::new();
49
50        let mut operations_cache = HashMap::new();
51
52        // Extract all services and methods from the descriptor pool
53        for service in descriptor_pool.services() {
54            let service_name = service.full_name().to_string();
55            services.insert(service_name.clone(), service.clone());
56
57            // Extract methods from this service
58            for method in service.methods() {
59                let method_name = method.name().to_string();
60                let operation_id = format!("{}.{}", service_name, method_name);
61                methods.insert(operation_id.clone(), method.clone());
62
63                // Cache the contract operation
64                let operation = ContractOperation {
65                    id: operation_id.clone(),
66                    name: method_name.clone(),
67                    operation_type: OperationType::GrpcMethod {
68                        service: service_name.clone(),
69                        method: method_name,
70                    },
71                    input_schema: Some(serde_json::json!({
72                        "type": method.input().full_name(),
73                        "streaming": method.is_client_streaming(),
74                    })),
75                    output_schema: Some(serde_json::json!({
76                        "type": method.output().full_name(),
77                        "streaming": method.is_server_streaming(),
78                    })),
79                    metadata: HashMap::new(),
80                };
81                operations_cache.insert(operation_id, operation);
82            }
83        }
84
85        Ok(Self {
86            contract_id,
87            version,
88            descriptor_pool,
89            services,
90            methods,
91            operations_cache,
92            metadata: HashMap::new(),
93        })
94    }
95
96    /// Create a gRPC contract from a proto file path
97    pub async fn from_proto_file(
98        contract_id: String,
99        version: String,
100        proto_file: &str,
101    ) -> Result<Self, ContractError> {
102        let proto_path = std::path::PathBuf::from(proto_file);
103        let parent = proto_path.parent().ok_or_else(|| {
104            ContractError::Other(format!(
105                "Could not determine parent directory for proto file: {}",
106                proto_file
107            ))
108        })?;
109
110        let descriptor_file = tempfile::Builder::new()
111            .prefix("mockforge-grpc-")
112            .suffix(".desc")
113            .tempfile()
114            .map_err(|e| ContractError::Other(format!("Failed to create temp file: {}", e)))?;
115        let descriptor_path = descriptor_file.path().to_path_buf();
116
117        let output = std::process::Command::new("protoc")
118            .arg("--include_imports")
119            .arg(format!("--descriptor_set_out={}", descriptor_path.to_string_lossy()))
120            .arg(format!("--proto_path={}", parent.to_string_lossy()))
121            .arg(proto_path.to_string_lossy().to_string())
122            .output()
123            .map_err(|e| {
124                ContractError::Other(format!(
125                    "Failed to execute protoc. Is it installed and in PATH? {}",
126                    e
127                ))
128            })?;
129
130        if !output.status.success() {
131            let stderr = String::from_utf8_lossy(&output.stderr);
132            return Err(ContractError::Other(format!(
133                "protoc failed for {}: {}",
134                proto_file, stderr
135            )));
136        }
137
138        let descriptor_bytes = std::fs::read(&descriptor_path).map_err(|e| {
139            ContractError::Other(format!(
140                "Failed to read generated descriptor set for {}: {}",
141                proto_file, e
142            ))
143        })?;
144
145        Self::from_descriptor_set(contract_id, version, &descriptor_bytes)
146    }
147
148    /// Create a gRPC contract from a compiled descriptor set (FileDescriptorSet bytes)
149    pub fn from_descriptor_set(
150        contract_id: String,
151        version: String,
152        descriptor_bytes: &[u8],
153    ) -> Result<Self, ContractError> {
154        let mut descriptor_pool = DescriptorPool::new();
155        descriptor_pool.decode_file_descriptor_set(descriptor_bytes).map_err(|e| {
156            ContractError::InvalidFormat(format!("Failed to decode descriptor set: {}", e))
157        })?;
158
159        Self::new(contract_id, version, Arc::new(descriptor_pool))
160    }
161
162    /// Compare two gRPC contracts and detect differences
163    fn diff_services(&self, other: &GrpcContract) -> Result<ContractDiffResult, ContractError> {
164        let mut mismatches = Vec::new();
165        let all_services: std::collections::HashSet<String> =
166            self.services.keys().chain(other.services.keys()).cloned().collect();
167
168        // Check for removed services (breaking change)
169        for service_name in &all_services {
170            if self.services.contains_key(service_name)
171                && !other.services.contains_key(service_name)
172            {
173                let mut context = HashMap::new();
174                context.insert("is_additive".to_string(), serde_json::json!(false));
175                context.insert("is_breaking".to_string(), serde_json::json!(true));
176                context.insert("change_category".to_string(), serde_json::json!("service_removed"));
177                context.insert("service".to_string(), serde_json::json!(service_name));
178
179                mismatches.push(Mismatch {
180                    mismatch_type: MismatchType::EndpointNotFound,
181                    path: service_name.clone(),
182                    method: None,
183                    expected: Some(format!("Service {} should exist", service_name)),
184                    actual: Some("Service removed".to_string()),
185                    description: format!("Service {} was removed", service_name),
186                    severity: MismatchSeverity::Critical,
187                    confidence: 1.0,
188                    context,
189                });
190            }
191        }
192
193        // Check for added services (non-breaking, additive)
194        for service_name in &all_services {
195            if !self.services.contains_key(service_name)
196                && other.services.contains_key(service_name)
197            {
198                let mut context = HashMap::new();
199                context.insert("is_additive".to_string(), serde_json::json!(true));
200                context.insert("is_breaking".to_string(), serde_json::json!(false));
201                context.insert("change_category".to_string(), serde_json::json!("service_added"));
202                context.insert("service".to_string(), serde_json::json!(service_name));
203
204                mismatches.push(Mismatch {
205                    mismatch_type: MismatchType::UnexpectedField,
206                    path: service_name.clone(),
207                    method: None,
208                    expected: None,
209                    actual: Some(format!("New service {}", service_name)),
210                    description: format!("New service {} was added", service_name),
211                    severity: MismatchSeverity::Low,
212                    confidence: 1.0,
213                    context,
214                });
215            }
216        }
217
218        // Compare methods in common services
219        for service_name in &all_services {
220            if let (Some(old_service), Some(new_service)) =
221                (self.services.get(service_name), other.services.get(service_name))
222            {
223                let method_diff = self.diff_methods(old_service, new_service)?;
224                mismatches.extend(method_diff);
225            }
226        }
227
228        let matches = mismatches.is_empty();
229        let confidence = if matches { 1.0 } else { 0.8 };
230
231        Ok(ContractDiffResult {
232            matches,
233            confidence,
234            mismatches,
235            recommendations: Vec::new(),
236            corrections: Vec::new(),
237            metadata: mockforge_foundation::contract_diff_types::DiffMetadata {
238                analyzed_at: chrono::Utc::now(),
239                request_source: "grpc_contract_diff".to_string(),
240                contract_version: Some(self.version.clone()),
241                contract_format: "protobuf".to_string(),
242                endpoint_path: "".to_string(),
243                http_method: "".to_string(),
244                request_count: 1,
245                llm_provider: None,
246                llm_model: None,
247            },
248        })
249    }
250
251    /// Classify a proto change as additive, breaking, or both
252    ///
253    /// Returns (is_additive, is_breaking) tuple
254    /// - Additive: New methods, new optional fields, new services
255    /// - Breaking: Removed methods, required field additions, type changes, signature changes
256    #[allow(dead_code)]
257    fn classify_proto_change(mismatch: &Mismatch) -> (bool, bool) {
258        match mismatch.mismatch_type {
259            // Breaking changes
260            MismatchType::EndpointNotFound => (false, true), // Method/service removed
261            MismatchType::TypeMismatch => (false, true),     // Type changed
262            MismatchType::SchemaMismatch => (false, true),   // Signature changed
263            MismatchType::MissingRequiredField => (false, true), // Required field added
264
265            // Additive changes
266            MismatchType::UnexpectedField => {
267                // Check severity - Low severity usually means additive (new method/field)
268                match mismatch.severity {
269                    MismatchSeverity::Low | MismatchSeverity::Info => (true, false),
270                    _ => (false, false), // Medium/High severity unexpected fields might be breaking
271                }
272            }
273
274            // Potentially breaking (depends on context)
275            MismatchType::FormatMismatch | MismatchType::ConstraintViolation => {
276                match mismatch.severity {
277                    MismatchSeverity::Critical | MismatchSeverity::High => (false, true),
278                    _ => (false, false),
279                }
280            }
281
282            // Not applicable for proto changes
283            _ => (false, false),
284        }
285    }
286
287    /// Compare methods between two service descriptors
288    fn diff_methods(
289        &self,
290        old_service: &ServiceDescriptor,
291        new_service: &ServiceDescriptor,
292    ) -> Result<Vec<Mismatch>, ContractError> {
293        let mut mismatches = Vec::new();
294        let service_name = old_service.full_name().to_string();
295
296        // Collect all method names
297        let old_methods: std::collections::HashSet<String> =
298            old_service.methods().map(|m| m.name().to_string()).collect();
299        let new_methods: std::collections::HashSet<String> =
300            new_service.methods().map(|m| m.name().to_string()).collect();
301
302        // Check for removed methods (breaking change)
303        for method_name in &old_methods {
304            if !new_methods.contains(method_name) {
305                let path = format!("{}.{}", service_name, method_name);
306                let mut context = HashMap::new();
307                context.insert("is_additive".to_string(), serde_json::json!(false));
308                context.insert("is_breaking".to_string(), serde_json::json!(true));
309                context.insert("change_category".to_string(), serde_json::json!("method_removed"));
310                context.insert("service".to_string(), serde_json::json!(service_name));
311                context.insert("method".to_string(), serde_json::json!(method_name));
312
313                mismatches.push(Mismatch {
314                    mismatch_type: MismatchType::EndpointNotFound,
315                    path: path.clone(),
316                    method: Some(method_name.clone()),
317                    expected: Some(format!("Method {}.{} should exist", service_name, method_name)),
318                    actual: Some("Method removed".to_string()),
319                    description: format!("Method {}.{} was removed", service_name, method_name),
320                    severity: MismatchSeverity::Critical,
321                    confidence: 1.0,
322                    context,
323                });
324            }
325        }
326
327        // Check for added methods (non-breaking, additive)
328        for method_name in &new_methods {
329            if !old_methods.contains(method_name) {
330                let path = format!("{}.{}", service_name, method_name);
331                let mut context = HashMap::new();
332                context.insert("is_additive".to_string(), serde_json::json!(true));
333                context.insert("is_breaking".to_string(), serde_json::json!(false));
334                context.insert("change_category".to_string(), serde_json::json!("method_added"));
335                context.insert("service".to_string(), serde_json::json!(service_name));
336                context.insert("method".to_string(), serde_json::json!(method_name));
337
338                mismatches.push(Mismatch {
339                    mismatch_type: MismatchType::UnexpectedField,
340                    path: path.clone(),
341                    method: Some(method_name.clone()),
342                    expected: None,
343                    actual: Some(format!("New method {}.{}", service_name, method_name)),
344                    description: format!("New method {}.{} was added", service_name, method_name),
345                    severity: MismatchSeverity::Low,
346                    confidence: 1.0,
347                    context,
348                });
349            }
350        }
351
352        // Compare method signatures for methods that exist in both
353        for method_name in old_methods.intersection(&new_methods) {
354            let old_method = old_service
355                .methods()
356                .find(|m| m.name() == method_name)
357                .ok_or_else(|| ContractError::OperationNotFound(method_name.clone()))?;
358            let new_method = new_service
359                .methods()
360                .find(|m| m.name() == method_name)
361                .ok_or_else(|| ContractError::OperationNotFound(method_name.clone()))?;
362
363            // Compare method signatures (input/output types, streaming flags)
364            let method_mismatches =
365                Self::diff_method_signatures(&old_method, &new_method, &service_name)?;
366            mismatches.extend(method_mismatches);
367
368            // Compare message fields if message types are the same
369            // This helps detect field-level changes even when message type names match
370            let old_input = old_method.input();
371            let new_input = new_method.input();
372            if old_input.full_name() == new_input.full_name() {
373                let input_field_mismatches = Self::diff_message_fields(
374                    &old_input,
375                    &new_input,
376                    &format!("{}.{}.input", service_name, method_name),
377                    &service_name,
378                    Some(method_name),
379                )?;
380                mismatches.extend(input_field_mismatches);
381            }
382
383            let old_output = old_method.output();
384            let new_output = new_method.output();
385            if old_output.full_name() == new_output.full_name() {
386                let output_field_mismatches = Self::diff_message_fields(
387                    &old_output,
388                    &new_output,
389                    &format!("{}.{}.output", service_name, method_name),
390                    &service_name,
391                    Some(method_name),
392                )?;
393                mismatches.extend(output_field_mismatches);
394            }
395        }
396
397        Ok(mismatches)
398    }
399
400    /// Compare method signatures (input/output types, streaming flags)
401    fn diff_method_signatures(
402        old_method: &MethodDescriptor,
403        new_method: &MethodDescriptor,
404        service_name: &str,
405    ) -> Result<Vec<Mismatch>, ContractError> {
406        let mut mismatches = Vec::new();
407        let method_name = old_method.name();
408        let path = format!("{}.{}", service_name, method_name);
409
410        // Check input type changes (breaking change)
411        if old_method.input().full_name() != new_method.input().full_name() {
412            let mut context = HashMap::new();
413            context.insert("is_additive".to_string(), serde_json::json!(false));
414            context.insert("is_breaking".to_string(), serde_json::json!(true));
415            context.insert("change_category".to_string(), serde_json::json!("input_type_changed"));
416            context.insert("service".to_string(), serde_json::json!(service_name));
417            context.insert("method".to_string(), serde_json::json!(method_name));
418            context
419                .insert("old_type".to_string(), serde_json::json!(old_method.input().full_name()));
420            context
421                .insert("new_type".to_string(), serde_json::json!(new_method.input().full_name()));
422
423            mismatches.push(Mismatch {
424                mismatch_type: MismatchType::TypeMismatch,
425                path: format!("{}.input", path),
426                method: Some(method_name.to_string()),
427                expected: Some(old_method.input().full_name().to_string()),
428                actual: Some(new_method.input().full_name().to_string()),
429                description: format!(
430                    "Input type changed from {} to {}",
431                    old_method.input().full_name(),
432                    new_method.input().full_name()
433                ),
434                severity: MismatchSeverity::High,
435                confidence: 1.0,
436                context,
437            });
438        }
439
440        // Check output type changes (breaking change)
441        if old_method.output().full_name() != new_method.output().full_name() {
442            let mut context = HashMap::new();
443            context.insert("is_additive".to_string(), serde_json::json!(false));
444            context.insert("is_breaking".to_string(), serde_json::json!(true));
445            context.insert("change_category".to_string(), serde_json::json!("output_type_changed"));
446            context.insert("service".to_string(), serde_json::json!(service_name));
447            context.insert("method".to_string(), serde_json::json!(method_name));
448            context
449                .insert("old_type".to_string(), serde_json::json!(old_method.output().full_name()));
450            context
451                .insert("new_type".to_string(), serde_json::json!(new_method.output().full_name()));
452
453            mismatches.push(Mismatch {
454                mismatch_type: MismatchType::TypeMismatch,
455                path: format!("{}.output", path),
456                method: Some(method_name.to_string()),
457                expected: Some(old_method.output().full_name().to_string()),
458                actual: Some(new_method.output().full_name().to_string()),
459                description: format!(
460                    "Output type changed from {} to {}",
461                    old_method.output().full_name(),
462                    new_method.output().full_name()
463                ),
464                severity: MismatchSeverity::High,
465                confidence: 1.0,
466                context,
467            });
468        }
469
470        // Check streaming flag changes (breaking change)
471        if old_method.is_client_streaming() != new_method.is_client_streaming() {
472            let mut context = HashMap::new();
473            context.insert("is_additive".to_string(), serde_json::json!(false));
474            context.insert("is_breaking".to_string(), serde_json::json!(true));
475            context.insert(
476                "change_category".to_string(),
477                serde_json::json!("streaming_config_changed"),
478            );
479            context.insert("service".to_string(), serde_json::json!(service_name));
480            context.insert("method".to_string(), serde_json::json!(method_name));
481            context.insert("streaming_type".to_string(), serde_json::json!("client"));
482            context.insert(
483                "old_value".to_string(),
484                serde_json::json!(old_method.is_client_streaming()),
485            );
486            context.insert(
487                "new_value".to_string(),
488                serde_json::json!(new_method.is_client_streaming()),
489            );
490
491            mismatches.push(Mismatch {
492                mismatch_type: MismatchType::SchemaMismatch,
493                path: path.clone(),
494                method: Some(method_name.to_string()),
495                expected: Some(format!("Client streaming: {}", old_method.is_client_streaming())),
496                actual: Some(format!("Client streaming: {}", new_method.is_client_streaming())),
497                description: format!(
498                    "Client streaming flag changed for {}.{}",
499                    service_name, method_name
500                ),
501                severity: MismatchSeverity::Critical,
502                confidence: 1.0,
503                context,
504            });
505        }
506
507        if old_method.is_server_streaming() != new_method.is_server_streaming() {
508            let mut context = HashMap::new();
509            context.insert("is_additive".to_string(), serde_json::json!(false));
510            context.insert("is_breaking".to_string(), serde_json::json!(true));
511            context.insert(
512                "change_category".to_string(),
513                serde_json::json!("streaming_config_changed"),
514            );
515            context.insert("service".to_string(), serde_json::json!(service_name));
516            context.insert("method".to_string(), serde_json::json!(method_name));
517            context.insert("streaming_type".to_string(), serde_json::json!("server"));
518            context.insert(
519                "old_value".to_string(),
520                serde_json::json!(old_method.is_server_streaming()),
521            );
522            context.insert(
523                "new_value".to_string(),
524                serde_json::json!(new_method.is_server_streaming()),
525            );
526
527            mismatches.push(Mismatch {
528                mismatch_type: MismatchType::SchemaMismatch,
529                path: path.clone(),
530                method: Some(method_name.to_string()),
531                expected: Some(format!("Server streaming: {}", old_method.is_server_streaming())),
532                actual: Some(format!("Server streaming: {}", new_method.is_server_streaming())),
533                description: format!(
534                    "Server streaming flag changed for {}.{}",
535                    service_name, method_name
536                ),
537                severity: MismatchSeverity::Critical,
538                confidence: 1.0,
539                context,
540            });
541        }
542
543        Ok(mismatches)
544    }
545
546    /// Compare message fields between two message descriptors
547    ///
548    /// Detects:
549    /// - Field removals (breaking)
550    /// - Field additions (additive if optional, breaking if required)
551    /// - Field type changes (breaking)
552    /// - Field number changes (breaking)
553    fn diff_message_fields(
554        old_message: &MessageDescriptor,
555        new_message: &MessageDescriptor,
556        path_prefix: &str,
557        service_name: &str,
558        method_name: Option<&str>,
559    ) -> Result<Vec<Mismatch>, ContractError> {
560        let mut mismatches = Vec::new();
561
562        // Collect field information
563        let old_fields: HashMap<u32, prost_reflect::FieldDescriptor> =
564            old_message.fields().map(|f| (f.number(), f)).collect();
565        let new_fields: HashMap<u32, prost_reflect::FieldDescriptor> =
566            new_message.fields().map(|f| (f.number(), f)).collect();
567
568        // Check for removed fields (breaking change)
569        for (field_number, old_field) in &old_fields {
570            if !new_fields.contains_key(field_number) {
571                let field_path = format!("{}.field_{}", path_prefix, field_number);
572                let mut context = HashMap::new();
573                context.insert("is_additive".to_string(), serde_json::json!(false));
574                context.insert("is_breaking".to_string(), serde_json::json!(true));
575                context.insert("change_category".to_string(), serde_json::json!("field_removed"));
576                context.insert("service".to_string(), serde_json::json!(service_name));
577                if let Some(method) = method_name {
578                    context.insert("method".to_string(), serde_json::json!(method));
579                }
580                context.insert("field_number".to_string(), serde_json::json!(*field_number));
581                context.insert("field_name".to_string(), serde_json::json!(old_field.name()));
582                context.insert(
583                    "field_type".to_string(),
584                    serde_json::json!(format!("{:?}", old_field.kind())),
585                );
586
587                mismatches.push(Mismatch {
588                    mismatch_type: MismatchType::EndpointNotFound,
589                    path: field_path.clone(),
590                    method: method_name.map(|s| s.to_string()),
591                    expected: Some(format!(
592                        "Field {} ({}) should exist",
593                        old_field.name(),
594                        field_number
595                    )),
596                    actual: Some("Field removed".to_string()),
597                    description: format!(
598                        "Field {} (number {}) was removed from message {}",
599                        old_field.name(),
600                        field_number,
601                        old_message.full_name()
602                    ),
603                    severity: MismatchSeverity::High,
604                    confidence: 1.0,
605                    context,
606                });
607            }
608        }
609
610        // Check for added fields
611        for (field_number, new_field) in &new_fields {
612            if !old_fields.contains_key(field_number) {
613                let field_path = format!("{}.field_{}", path_prefix, field_number);
614                let mut context = HashMap::new();
615                // In proto3, all fields are optional by default, so new fields are additive
616                // In proto2, if the field is required, it's breaking
617                let is_required = new_field.cardinality() == prost_reflect::Cardinality::Required;
618                context.insert("is_additive".to_string(), serde_json::json!(!is_required));
619                context.insert("is_breaking".to_string(), serde_json::json!(is_required));
620                context.insert(
621                    "change_category".to_string(),
622                    serde_json::json!(if is_required {
623                        "required_field_added"
624                    } else {
625                        "field_added"
626                    }),
627                );
628                context.insert("service".to_string(), serde_json::json!(service_name));
629                if let Some(method) = method_name {
630                    context.insert("method".to_string(), serde_json::json!(method));
631                }
632                context.insert("field_number".to_string(), serde_json::json!(*field_number));
633                context.insert("field_name".to_string(), serde_json::json!(new_field.name()));
634                context.insert(
635                    "field_type".to_string(),
636                    serde_json::json!(format!("{:?}", new_field.kind())),
637                );
638                context.insert("is_required".to_string(), serde_json::json!(is_required));
639
640                mismatches.push(Mismatch {
641                    mismatch_type: if is_required {
642                        MismatchType::MissingRequiredField
643                    } else {
644                        MismatchType::UnexpectedField
645                    },
646                    path: field_path.clone(),
647                    method: method_name.map(|s| s.to_string()),
648                    expected: None,
649                    actual: Some(format!(
650                        "New field {} (number {})",
651                        new_field.name(),
652                        field_number
653                    )),
654                    description: format!(
655                        "New field {} (number {}) was added to message {} ({})",
656                        new_field.name(),
657                        field_number,
658                        new_message.full_name(),
659                        if is_required {
660                            "required - breaking"
661                        } else {
662                            "optional - additive"
663                        }
664                    ),
665                    severity: if is_required {
666                        MismatchSeverity::High
667                    } else {
668                        MismatchSeverity::Low
669                    },
670                    confidence: 1.0,
671                    context,
672                });
673            }
674        }
675
676        // Check for field type changes (same field number, different type)
677        for (field_number, old_field) in &old_fields {
678            if let Some(new_field) = new_fields.get(field_number) {
679                // Check if field name changed (breaking)
680                if old_field.name() != new_field.name() {
681                    let field_path = format!("{}.field_{}", path_prefix, field_number);
682                    let mut context = HashMap::new();
683                    context.insert("is_additive".to_string(), serde_json::json!(false));
684                    context.insert("is_breaking".to_string(), serde_json::json!(true));
685                    context.insert(
686                        "change_category".to_string(),
687                        serde_json::json!("field_name_changed"),
688                    );
689                    context.insert("service".to_string(), serde_json::json!(service_name));
690                    if let Some(method) = method_name {
691                        context.insert("method".to_string(), serde_json::json!(method));
692                    }
693                    context.insert("field_number".to_string(), serde_json::json!(*field_number));
694                    context.insert("old_name".to_string(), serde_json::json!(old_field.name()));
695                    context.insert("new_name".to_string(), serde_json::json!(new_field.name()));
696
697                    mismatches.push(Mismatch {
698                        mismatch_type: MismatchType::SchemaMismatch,
699                        path: field_path.clone(),
700                        method: method_name.map(|s| s.to_string()),
701                        expected: Some(format!("Field name: {}", old_field.name())),
702                        actual: Some(format!("Field name: {}", new_field.name())),
703                        description: format!(
704                            "Field name changed from {} to {} (field number {})",
705                            old_field.name(),
706                            new_field.name(),
707                            field_number
708                        ),
709                        severity: MismatchSeverity::High,
710                        confidence: 1.0,
711                        context,
712                    });
713                }
714
715                // Check if field type changed (breaking)
716                if old_field.kind() != new_field.kind() {
717                    let field_path = format!("{}.field_{}", path_prefix, field_number);
718                    let mut context = HashMap::new();
719                    context.insert("is_additive".to_string(), serde_json::json!(false));
720                    context.insert("is_breaking".to_string(), serde_json::json!(true));
721                    context.insert(
722                        "change_category".to_string(),
723                        serde_json::json!("field_type_changed"),
724                    );
725                    context.insert("service".to_string(), serde_json::json!(service_name));
726                    if let Some(method) = method_name {
727                        context.insert("method".to_string(), serde_json::json!(method));
728                    }
729                    context.insert("field_number".to_string(), serde_json::json!(*field_number));
730                    context.insert("field_name".to_string(), serde_json::json!(old_field.name()));
731                    context.insert(
732                        "old_type".to_string(),
733                        serde_json::json!(format!("{:?}", old_field.kind())),
734                    );
735                    context.insert(
736                        "new_type".to_string(),
737                        serde_json::json!(format!("{:?}", new_field.kind())),
738                    );
739
740                    mismatches.push(Mismatch {
741                        mismatch_type: MismatchType::TypeMismatch,
742                        path: field_path.clone(),
743                        method: method_name.map(|s| s.to_string()),
744                        expected: Some(format!("Field type: {:?}", old_field.kind())),
745                        actual: Some(format!("Field type: {:?}", new_field.kind())),
746                        description: format!(
747                            "Field {} type changed from {:?} to {:?}",
748                            old_field.name(),
749                            old_field.kind(),
750                            new_field.kind()
751                        ),
752                        severity: MismatchSeverity::High,
753                        confidence: 1.0,
754                        context,
755                    });
756                }
757
758                // Check if cardinality changed (e.g., optional to required - breaking)
759                if old_field.cardinality() != new_field.cardinality() {
760                    let old_cardinality = old_field.cardinality();
761                    let new_cardinality = new_field.cardinality();
762                    let is_breaking = matches!(
763                        (old_cardinality, new_cardinality),
764                        (
765                            prost_reflect::Cardinality::Optional
766                                | prost_reflect::Cardinality::Repeated,
767                            prost_reflect::Cardinality::Required
768                        )
769                    );
770
771                    let field_path = format!("{}.field_{}", path_prefix, field_number);
772                    let mut context = HashMap::new();
773                    context.insert("is_additive".to_string(), serde_json::json!(!is_breaking));
774                    context.insert("is_breaking".to_string(), serde_json::json!(is_breaking));
775                    context.insert(
776                        "change_category".to_string(),
777                        serde_json::json!("field_cardinality_changed"),
778                    );
779                    context.insert("service".to_string(), serde_json::json!(service_name));
780                    if let Some(method) = method_name {
781                        context.insert("method".to_string(), serde_json::json!(method));
782                    }
783                    context.insert("field_number".to_string(), serde_json::json!(*field_number));
784                    context.insert("field_name".to_string(), serde_json::json!(old_field.name()));
785                    context.insert(
786                        "old_cardinality".to_string(),
787                        serde_json::json!(format!("{:?}", old_cardinality)),
788                    );
789                    context.insert(
790                        "new_cardinality".to_string(),
791                        serde_json::json!(format!("{:?}", new_cardinality)),
792                    );
793
794                    mismatches.push(Mismatch {
795                        mismatch_type: if is_breaking {
796                            MismatchType::MissingRequiredField
797                        } else {
798                            MismatchType::SchemaMismatch
799                        },
800                        path: field_path.clone(),
801                        method: method_name.map(|s| s.to_string()),
802                        expected: Some(format!("Cardinality: {:?}", old_cardinality)),
803                        actual: Some(format!("Cardinality: {:?}", new_cardinality)),
804                        description: format!(
805                            "Field {} cardinality changed from {:?} to {:?} ({})",
806                            old_field.name(),
807                            old_cardinality,
808                            new_cardinality,
809                            if is_breaking {
810                                "breaking"
811                            } else {
812                                "non-breaking"
813                            }
814                        ),
815                        severity: if is_breaking {
816                            MismatchSeverity::High
817                        } else {
818                            MismatchSeverity::Medium
819                        },
820                        confidence: 1.0,
821                        context,
822                    });
823                }
824            }
825        }
826
827        Ok(mismatches)
828    }
829}
830
831#[async_trait::async_trait]
832impl ProtocolContract for GrpcContract {
833    fn protocol(&self) -> Protocol {
834        Protocol::Grpc
835    }
836
837    fn contract_id(&self) -> &str {
838        &self.contract_id
839    }
840
841    fn version(&self) -> &str {
842        &self.version
843    }
844
845    fn operations(&self) -> Vec<ContractOperation> {
846        self.operations_cache.values().cloned().collect()
847    }
848
849    fn get_operation(&self, operation_id: &str) -> Option<&ContractOperation> {
850        self.operations_cache.get(operation_id)
851    }
852
853    async fn diff(
854        &self,
855        other: &dyn ProtocolContract,
856    ) -> Result<ContractDiffResult, ContractError> {
857        // Ensure the other contract is also a gRPC contract
858        if other.protocol() != Protocol::Grpc {
859            return Err(ContractError::UnsupportedProtocol(other.protocol()));
860        }
861
862        // Try to downcast to GrpcContract
863        // Since we can't use downcast_ref on trait objects, we'll need to use a different approach
864        // For now, we'll require that contracts of the same protocol can be compared
865        // In a full implementation, we might use a type-erased approach or require
866        // contracts to provide a way to access their internal representation
867
868        // This is a limitation of the current design - we need a way to compare
869        // contracts of the same protocol type
870        Err(ContractError::Other(
871            "Direct comparison of GrpcContract instances requires type information. \
872             Use GrpcContract::diff_services() for comparing two GrpcContract instances."
873                .to_string(),
874        ))
875    }
876
877    async fn validate(
878        &self,
879        operation_id: &str,
880        request: &ContractRequest,
881    ) -> Result<ValidationResult, ContractError> {
882        // Check if the operation exists
883        let Some(method) = self.methods.get(operation_id) else {
884            return Ok(ValidationResult {
885                valid: false,
886                errors: vec![ValidationError {
887                    message: format!("Method {} not found in contract", operation_id),
888                    path: Some(operation_id.to_string()),
889                    code: Some("METHOD_NOT_FOUND".to_string()),
890                }],
891                warnings: Vec::new(),
892            });
893        };
894
895        // Get the input message descriptor for this method
896        let input_message = method.input();
897        let message_name = input_message.full_name().to_string();
898        let field_count = input_message.fields().count();
899
900        // Validate the payload against the protobuf schema
901        let mut errors = Vec::new();
902        let mut warnings = Vec::new();
903
904        // Try to deserialize the payload as a protobuf message
905        // For gRPC, the payload should be a serialized protobuf message
906        if request.payload.is_empty() {
907            // Empty payload might be valid for methods with no input
908            if field_count > 0 {
909                // Check if all fields are optional (proto3 has no required fields by default)
910                // But we can still validate that the message structure is correct
911                warnings.push("Empty payload provided for method with input message".to_string());
912            }
913        } else {
914            // Attempt to deserialize the payload
915            // Convert Vec<u8> to bytes::Bytes for prost_reflect
916            use bytes::Bytes;
917            let payload_bytes = Bytes::from(request.payload.clone());
918
919            // Clone the input_message descriptor since decode takes ownership
920            let input_message_clone = input_message;
921            match prost_reflect::DynamicMessage::decode(input_message_clone, payload_bytes) {
922                Ok(_message) => {
923                    // Validate required fields (proto2) or check field presence
924                    // In proto3, all fields are optional, but we can still validate types
925                    // prost_reflect handles type validation during deserialization
926                    // If we got here, the message structure is valid
927                    // Field validation is handled by prost_reflect during deserialization
928                    // If deserialization succeeded, the types are correct
929
930                    // Check for unknown fields (fields not in the schema)
931                    // This is handled by prost_reflect during deserialization
932                    // If deserialization succeeded, the message structure is valid
933                }
934                Err(e) => {
935                    // Deserialization failed - this is a validation error
936                    errors.push(ValidationError {
937                        message: format!(
938                            "Failed to deserialize protobuf message: {}. Expected message type: {}",
939                            e, message_name
940                        ),
941                        path: Some(operation_id.to_string()),
942                        code: Some("DESERIALIZATION_ERROR".to_string()),
943                    });
944                }
945            }
946        }
947
948        // Validate streaming configuration
949        if method.is_client_streaming() && !request.metadata.contains_key("streaming") {
950            warnings.push(
951                "Method is client-streaming but request doesn't indicate streaming".to_string(),
952            );
953        }
954
955        Ok(ValidationResult {
956            valid: errors.is_empty(),
957            errors,
958            warnings,
959        })
960    }
961
962    fn get_schema(&self, operation_id: &str) -> Option<serde_json::Value> {
963        self.methods.get(operation_id).map(|method| {
964            serde_json::json!({
965                "input": {
966                    "type": method.input().full_name(),
967                    "streaming": method.is_client_streaming(),
968                },
969                "output": {
970                    "type": method.output().full_name(),
971                    "streaming": method.is_server_streaming(),
972                },
973            })
974        })
975    }
976
977    fn to_json(&self) -> Result<serde_json::Value, ContractError> {
978        let operations: Vec<serde_json::Value> = self
979            .operations()
980            .iter()
981            .map(|op| {
982                serde_json::json!({
983                    "id": op.id,
984                    "name": op.name,
985                    "type": op.operation_type,
986                    "input_schema": op.input_schema,
987                    "output_schema": op.output_schema,
988                })
989            })
990            .collect();
991
992        Ok(serde_json::json!({
993            "contract_id": self.contract_id,
994            "version": self.version,
995            "protocol": "grpc",
996            "services": self.services.keys().collect::<Vec<_>>(),
997            "operations": operations,
998            "metadata": self.metadata,
999        }))
1000    }
1001}
1002
1003/// Helper function to compare two GrpcContract instances
1004pub fn diff_grpc_contracts(
1005    old_contract: &GrpcContract,
1006    new_contract: &GrpcContract,
1007) -> Result<ContractDiffResult, ContractError> {
1008    old_contract.diff_services(new_contract)
1009}
1010
1011#[cfg(test)]
1012mod tests {
1013    #[test]
1014    fn test_grpc_contract_creation() {
1015        // Placeholder: the body is empty, so this test asserts nothing and only
1016        // confirms that the module compiles. Constructing a GrpcContract needs a
1017        // descriptor pool; a real test would build one from a sample proto file.
1018    }
1019}