use ryo_analysis::context::AnalysisContext;
use ryo_analysis::{SymbolId, SymbolKind};
use ryo_executor::{SpecRelation, SpecRelationKind};
use super::{is_framework_type, SpecSuggest};
use crate::{
MutationSpec, OpportunityId, SafetyLevel, Suggest, SuggestCategory, SuggestError,
SuggestLocation, SuggestOpportunity, SuggestResult,
};
/// Suggestion rule that checks Spec relations are declared in both directions.
///
/// When the Spec alias of one type references another type whose own Spec alias does not
/// reference the first type back, the rule reports the missing reverse relation.
#[derive(Debug, Clone)]
pub struct BidirectionalRelation {
    /// Suffix that marks a type alias as a Spec alias (`"Spec"` unless overridden).
    suffix: String,
    /// Group name written into reported details and generated specs
    /// (`"DomainGroup"` unless overridden).
    default_group: String,
}
impl BidirectionalRelation {
/// Creates a rule with the default `"Spec"` suffix and `"DomainGroup"` group.
pub fn new() -> Self {
    let suffix = String::from("Spec");
    let default_group = String::from("DomainGroup");
    Self {
        suffix,
        default_group,
    }
}
/// Creates a rule that uses `suffix` to recognise Spec aliases, keeping the default
/// `"DomainGroup"` group.
pub fn with_suffix(suffix: impl Into<String>) -> Self {
    let suffix: String = suffix.into();
    Self {
        suffix,
        default_group: String::from("DomainGroup"),
    }
}
pub fn with_group(mut self, group: impl Into<String>) -> Self {
self.default_group = group.into();
self
}
/// Looks up the Spec type alias belonging to `type_name`.
///
/// The expected alias name is `type_name` followed by the configured suffix. Returns the
/// id of the first type alias in the registry whose path name equals it, together with
/// that alias name, or `None` when no type alias matches.
fn find_spec_for_type(
    &self,
    ctx: &AnalysisContext,
    type_name: &str,
) -> Option<(SymbolId, String)> {
    let wanted = format!("{}{}", type_name, self.suffix);
    let alias_id = ctx
        .registry
        .iter_by_kind(SymbolKind::TypeAlias)
        .find(|&candidate| {
            matches!(ctx.registry.path(candidate), Some(path) if path.name() == wanted)
        })?;
    Some((alias_id, wanted))
}
/// Collects the names of the struct and enum types used by the Spec alias `spec_id`.
///
/// Usage comes from the typeflow graph. Symbols without a registry path, symbols that are
/// neither structs nor enums, framework types, and `base_type` itself are left out.
fn extract_relation_targets(
    &self,
    ctx: &AnalysisContext,
    spec_id: SymbolId,
    base_type: &str,
) -> Vec<String> {
    let graph = ctx.typeflow_graph();
    let mut related = Vec::new();
    for dependency in graph.types_used_by(spec_id) {
        let Some(dep_path) = ctx.registry.path(dependency) else {
            continue;
        };
        let is_data_type = matches!(
            ctx.registry.kind(dependency),
            Some(SymbolKind::Struct | SymbolKind::Enum)
        );
        if !is_data_type {
            continue;
        }
        let name = dep_path.name();
        // Keep only types that are neither framework-provided nor the alias's own base.
        if !is_framework_type(name) && name != base_type {
            related.push(name.to_string());
        }
    }
    related
}
/// Reports whether the Spec alias `spec_id` uses a struct or enum named `target_type`.
///
/// Usage comes from the typeflow graph; symbols without a registry path are ignored.
fn spec_has_relation_to(
    &self,
    ctx: &AnalysisContext,
    spec_id: SymbolId,
    target_type: &str,
) -> bool {
    let graph = ctx.typeflow_graph();
    for dependency in graph.types_used_by(spec_id) {
        let Some(dep_path) = ctx.registry.path(dependency) else {
            continue;
        };
        let is_data_type = matches!(
            ctx.registry.kind(dependency),
            Some(SymbolKind::Struct | SymbolKind::Enum)
        );
        if is_data_type && dep_path.name() == target_type {
            return true;
        }
    }
    false
}
}
impl Default for BidirectionalRelation {
fn default() -> Self {
Self::new()
}
}
/// Rule code passed to `create_spec_opportunity` for every opportunity this rule reports.
const RULE_CODE: &str = "RS008";
impl SpecSuggest for BidirectionalRelation {
    /// Returns the suffix this rule was configured with.
    fn spec_suffix(&self) -> &str {
        self.suffix.as_str()
    }
}
impl Suggest for BidirectionalRelation {
/// Returns the rule's identifier, `bidirectional-relation`.
fn name(&self) -> &'static str {
"bidirectional-relation"
}
/// Returns a one-line, human-readable summary of what the rule checks.
fn description(&self) -> &str {
"Ensures consistency of bidirectional Spec relations"
}
/// Classifies the rule under `SuggestCategory::Pattern`.
fn category(&self) -> SuggestCategory {
SuggestCategory::Pattern
}
/// Reports `SafetyLevel::Confirm` as the rule's safety level.
// NOTE(review): presumably this means the generated mutation should be confirmed before
// it is applied; verify against the `SafetyLevel` documentation.
fn safety_level(&self) -> SafetyLevel {
SafetyLevel::Confirm }
/// Returns the rule's priority weight, fixed at `0.9`.
fn priority_weight(&self) -> f32 {
0.9 }
/// Finds Spec relations that are declared in one direction only.
///
/// For every Spec alias under inspection, each struct/enum it uses is looked up. When that
/// type has a Spec alias of its own which does not use the source type back, one
/// opportunity is reported, located at the target Spec alias.
///
/// `symbols` restricts the scan; when it is empty every type alias in the registry is
/// scanned. Explicitly supplied symbols that are not type aliases are skipped, so both
/// modes agree with `find_spec_for_type`, which only ever treats type aliases as Specs.
///
/// Each unordered pair of type names is reported at most once per call.
fn detect(&self, ctx: &AnalysisContext, symbols: &[SymbolId]) -> Vec<SuggestOpportunity> {
    use super::{create_spec_opportunity, SpecDetails};
    use std::collections::HashSet;

    let mut opportunities = Vec::new();
    let mut next_id = 0u32;
    let mut reported_pairs: HashSet<(String, String)> = HashSet::new();

    let symbols_to_check: Vec<SymbolId> = if symbols.is_empty() {
        ctx.registry.iter_by_kind(SymbolKind::TypeAlias).collect()
    } else {
        // A struct or function that merely happens to be named `*Spec` is not a Spec alias.
        symbols
            .iter()
            .copied()
            .filter(|&id| matches!(ctx.registry.kind(id), Some(SymbolKind::TypeAlias)))
            .collect()
    };

    for spec_id in symbols_to_check {
        let Some(path) = ctx.registry.path(spec_id) else {
            continue;
        };
        let alias_name = path.name();
        if !self.is_spec_alias(alias_name) {
            continue;
        }
        let Some(source_type) = self
            .extract_base_type(alias_name)
            .map(|base| base.to_string())
        else {
            continue;
        };

        for target_type in self.extract_relation_targets(ctx, spec_id, &source_type) {
            // Targets without a Spec alias of their own cannot declare a reverse relation.
            let Some((target_spec_id, target_spec_name)) =
                self.find_spec_for_type(ctx, &target_type)
            else {
                continue;
            };
            // The relation is already bidirectional.
            if self.spec_has_relation_to(ctx, target_spec_id, &source_type) {
                continue;
            }

            // Order-normalised so the same two types map to one key whichever side is the
            // source.
            let pair_key = if source_type < target_type {
                (source_type.clone(), target_type.clone())
            } else {
                (target_type.clone(), source_type.clone())
            };
            // `insert` returns `false` when the pair was already present. The pair is
            // recorded before the location lookup, so a pair whose location cannot be
            // resolved is not retried within this call.
            if !reported_pairs.insert(pair_key) {
                continue;
            }

            let Some(location) = SuggestLocation::from_context(ctx, target_spec_id) else {
                continue;
            };

            let opp = create_spec_opportunity(
                RULE_CODE,
                OpportunityId::new(next_id),
                // The target Spec comes first: `to_mutation_specs` reads `targets.first()`.
                vec![target_spec_id, spec_id],
                location,
                format!(
                    "`{}` has relation to `{}`, but `{}` has no relation back to `{}`",
                    alias_name, target_type, target_spec_name, source_type
                ),
                // NOTE(review): presumably a confidence score; confirm against
                // `create_spec_opportunity`.
                0.75,
                SpecDetails {
                    alias_name: Some(target_spec_name.clone()),
                    base_type: Some(target_type.clone()),
                    group: Some(self.default_group.clone()),
                    related_types: vec![source_type.clone()],
                    suggestion: Some(format!(
                        "Add `RelatedTo<{}>` to {}",
                        source_type, target_spec_name
                    )),
                },
            );
            opportunities.push(opp);
            next_id += 1;
        }
    }
    opportunities
}
fn to_mutation_specs(
&self,
ctx: &AnalysisContext,
opportunity: &SuggestOpportunity,
) -> SuggestResult<Vec<MutationSpec>> {
let target_spec_id = match opportunity.targets.first() {
Some(id) => *id,
None => return Ok(Vec::new()),
};
let spec_path = match ctx.registry.path(target_spec_id) {
Some(p) => p,
None => return Ok(Vec::new()),
};
let alias_name = spec_path.name().to_string();
let base_type = match self.extract_base_type(&alias_name) {
Some(bt) => bt.to_string(),
None => return Ok(Vec::new()),
};
let module_path =
self.get_module_path(spec_path)
.ok_or_else(|| SuggestError::ModulePathResolution {
path: spec_path.to_string(),
})?;
let module_id = match ctx.registry.lookup(&module_path) {
Some(id) => id,
None => return Ok(Vec::new()),
};
let target_type_id = ctx
.registry
.iter()
.find(|(_, path)| path.name() == base_type)
.map(|(id, _)| id);
let target_type_id = match target_type_id {
Some(id) => id,
None => return Ok(Vec::new()),
};
let related_type = match &opportunity.context {
crate::OpportunityContext::Spec { related_types, .. } => related_types.first().cloned(),
_ => None,
};
let source_type = match related_type {
Some(t) => t,
None => return Ok(Vec::new()),
};
let existing_relations = self.extract_relation_targets(ctx, target_spec_id, &base_type);
let mut relations: Vec<SpecRelation> = existing_relations
.into_iter()
.map(|target| SpecRelation::new(SpecRelationKind::RelatedTo, target))
.collect();
relations.push(SpecRelation::new(SpecRelationKind::RelatedTo, source_type));
Ok(vec![MutationSpec::AddSpec {
type_id: target_type_id,
module_id,
group: self.default_group.clone(),
alias_name: Some(alias_name),
relations,
}])
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Names carrying the suffix after a base name are accepted; the bare suffix and
    /// unsuffixed names are rejected.
    #[test]
    fn test_is_spec_alias() {
        let relation = BidirectionalRelation::new();
        for accepted in ["UserSpec", "OrderSpec"] {
            assert!(relation.is_spec_alias(accepted));
        }
        for rejected in ["Spec", "User"] {
            assert!(!relation.is_spec_alias(rejected));
        }
    }

    /// The base type is the alias name without the suffix; the bare suffix has none.
    #[test]
    fn test_extract_base_type() {
        let relation = BidirectionalRelation::new();
        let cases = [
            ("UserSpec", Some("User")),
            ("OrderSpec", Some("Order")),
            ("Spec", None),
        ];
        for (alias, expected) in cases {
            assert_eq!(relation.extract_base_type(alias), expected);
        }
    }

    /// `with_group` overrides the default group name.
    #[test]
    fn test_with_group() {
        let relation = BidirectionalRelation::new().with_group("CustomGroup");
        assert_eq!(relation.default_group, "CustomGroup");
    }

    /// A custom suffix replaces the default one for alias recognition.
    #[test]
    fn test_with_suffix() {
        let relation = BidirectionalRelation::with_suffix("Definition");
        assert_eq!(relation.suffix, "Definition");
        assert!(relation.is_spec_alias("UserDefinition"));
        assert!(!relation.is_spec_alias("UserSpec"));
    }
}