cadi_core/ghost/resolver.rs

use std::collections::HashSet;

use crate::graph::GraphStore;
use super::policy::ExpansionPolicy;
use super::analyzer::DependencyAnalyzer;
4
5/// Ghost Import Resolver
6///
7/// Automatically expands atom context to include necessary dependencies
8/// so LLMs don't hallucinate missing types.
9pub struct GhostResolver {
10    graph: GraphStore,
11}
12
/// Outcome of expanding a set of atoms with their ghost imports.
#[derive(Debug)]
pub struct ExpansionResult {
    /// Every atom in the expanded context: the requested ones plus any ghosts.
    pub atoms: Vec<String>,
    /// The atoms that expansion added on top of what was requested.
    pub ghost_atoms: Vec<String>,
    /// Whether expansion was cut short by a limit.
    pub truncated: bool,
    /// Estimated token count of the expanded context.
    pub total_tokens: usize,
    /// Human-readable account of which atoms were added and why.
    pub explanation: String,
}
26
27impl GhostResolver {
28    pub fn new(graph: GraphStore) -> Self {
29        Self { graph }
30    }
31
32    pub fn analyzer(&self) -> DependencyAnalyzer<'_> {
33        DependencyAnalyzer::new(&self.graph)
34    }
35
36    /// Resolve ghost imports for a set of atoms
37    pub async fn resolve(&self, atom_ids: &[String]) -> Result<ExpansionResult, Box<dyn std::error::Error + Send + Sync>> {
38        self.resolve_with_policy(atom_ids, &ExpansionPolicy::default()).await
39    }
40
41    /// Resolve with custom policy
42    pub async fn resolve_with_policy(
43        &self,
44        atom_ids: &[String],
45        policy: &ExpansionPolicy,
46    ) -> Result<ExpansionResult, Box<dyn std::error::Error + Send + Sync>> {
47        // Simulate the expansion first
48        let simulation = self.analyzer().simulate_expansion(atom_ids, policy)?;
49
50        // Separate original vs ghost atoms
51        let mut ghost_atoms = Vec::new();
52        let mut explanations = Vec::new();
53
54        for atom_id in &simulation.included_atoms {
55            if !atom_ids.contains(atom_id) {
56                ghost_atoms.push(atom_id.clone());
57
58                // Find why this atom was included
59                if let Some(reason) = self.find_inclusion_reason(atom_id, atom_ids, policy).await? {
60                    explanations.push(reason);
61                }
62            }
63        }
64
65        Ok(ExpansionResult {
66            atoms: simulation.included_atoms,
67            ghost_atoms,
68            truncated: simulation.truncated,
69            total_tokens: simulation.total_tokens,
70            explanation: explanations.join("\n"),
71        })
72    }
73
74    /// Get optimal policy for a set of atoms
75    pub async fn suggest_policy(&self, atom_ids: &[String]) -> Result<ExpansionPolicy, Box<dyn std::error::Error + Send + Sync>> {
76        let analysis = self.analyzer().analyze_dependencies(atom_ids)?;
77
78        // Calculate average dependencies per atom
79        let avg_deps = analysis.iter()
80            .map(|info| info.dependencies.len())
81            .sum::<usize>() as f64 / analysis.len() as f64;
82
83        // Calculate total tokens
84        let total_tokens = analysis.iter()
85            .map(|info| info.token_estimate)
86            .sum::<usize>();
87
88        // Suggest policy based on complexity
89        let mut policy = ExpansionPolicy::default();
90
91        if avg_deps < 2.0 && total_tokens < 1000 {
92            // Simple case - be conservative
93            policy = ExpansionPolicy::conservative();
94        } else if avg_deps > 5.0 || total_tokens > 5000 {
95            // Complex case - be more aggressive but careful
96            policy.max_depth = 3;
97            policy.max_atoms = 30;
98            policy.max_tokens = 6000;
99        }
100
101        Ok(policy)
102    }
103
104    async fn find_inclusion_reason(
105        &self,
106        atom_id: &str,
107        original_atoms: &[String],
108        policy: &ExpansionPolicy,
109    ) -> Result<Option<String>, Box<dyn std::error::Error + Send + Sync>> {
110        // Find which original atom(s) led to this inclusion
111        for original_id in original_atoms {
112            let deps = self.graph.get_dependencies(original_id)?;
113            for (edge_type, dep_id) in deps {
114                if dep_id == atom_id && policy.follow_edges.contains(&edge_type) {
115                    return Ok(Some(format!(
116                        "Added '{}' because '{}' references it via {:?}",
117                        atom_id, original_id, edge_type
118                    )));
119                }
120            }
121        }
122
123        // Check transitive dependencies
124        for original_id in original_atoms {
125            if let Some(path) = self.find_dependency_path(original_id, atom_id, policy, 3)? {
126                return Ok(Some(format!(
127                    "Added '{}' through dependency chain: {}",
128                    atom_id, path
129                )));
130            }
131        }
132
133        Ok(None)
134    }
135
136    fn find_dependency_path(
137        &self,
138        from: &str,
139        to: &str,
140        policy: &ExpansionPolicy,
141        max_depth: usize,
142    ) -> Result<Option<String>, Box<dyn std::error::Error + Send + Sync>> {
143        if max_depth == 0 {
144            return Ok(None);
145        }
146
147        let deps = self.graph.get_dependencies(from)?;
148        for (edge_type, dep_id) in deps {
149            if dep_id == to && policy.follow_edges.contains(&edge_type) {
150                return Ok(Some(format!("{} -> {}", from, to)));
151            }
152
153            if let Some(path) = self.find_dependency_path(&dep_id, to, policy, max_depth - 1)? {
154                return Ok(Some(format!("{} -> {}", from, path)));
155            }
156        }
157
158        Ok(None)
159    }
160}