baobao_codegen/generation/
imports.rs1use std::collections::{BTreeSet, HashMap};
4
5use indexmap::IndexMap;
6
7#[derive(Debug, Clone, Default)]
28pub struct ImportCollector {
29 imports: IndexMap<String, BTreeSet<String>>,
31}
32
33impl ImportCollector {
34 pub fn new() -> Self {
36 Self::default()
37 }
38
39 pub fn add(&mut self, module: &str, symbol: &str) {
41 self.imports
42 .entry(module.to_string())
43 .or_default()
44 .insert(symbol.to_string());
45 }
46
47 pub fn add_module(&mut self, module: &str) {
49 self.imports.entry(module.to_string()).or_default();
50 }
51
52 pub fn merge(&mut self, other: &ImportCollector) {
54 for (module, symbols) in &other.imports {
55 let entry = self.imports.entry(module.clone()).or_default();
56 entry.extend(symbols.iter().cloned());
57 }
58 }
59
60 pub fn has_module(&self, module: &str) -> bool {
62 self.imports.contains_key(module)
63 }
64
65 pub fn has_symbol(&self, module: &str, symbol: &str) -> bool {
67 self.imports
68 .get(module)
69 .is_some_and(|symbols| symbols.contains(symbol))
70 }
71
72 pub fn iter(&self) -> impl Iterator<Item = (&str, &BTreeSet<String>)> {
74 self.imports.iter().map(|(k, v)| (k.as_str(), v))
75 }
76
77 pub fn is_empty(&self) -> bool {
79 self.imports.is_empty()
80 }
81
82 pub fn len(&self) -> usize {
84 self.imports.len()
85 }
86}
87
/// Version and feature requirements for a single crate dependency.
///
/// Built with [`DependencySpec::new`] and refined with the consuming builder
/// methods [`with_features`](Self::with_features) and
/// [`optional`](Self::optional).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DependencySpec {
    /// Version requirement, stored verbatim as given (e.g. `"1.0"`).
    pub version: String,
    /// Features to enable for this dependency.
    pub features: Vec<String>,
    /// Whether the dependency is marked optional.
    pub optional: bool,
}

impl DependencySpec {
    /// Creates a non-optional spec for `version` with no features.
    #[must_use]
    pub fn new(version: impl Into<String>) -> Self {
        Self {
            version: version.into(),
            features: Vec::new(),
            optional: false,
        }
    }

    /// Sets the feature list, replacing any features set earlier.
    #[must_use]
    pub fn with_features(mut self, features: impl IntoIterator<Item = impl Into<String>>) -> Self {
        self.features = features.into_iter().map(Into::into).collect();
        self
    }

    /// Marks the dependency as optional.
    #[must_use]
    pub fn optional(mut self) -> Self {
        self.optional = true;
        self
    }
}
121
122#[derive(Debug, Clone, Default)]
138pub struct DependencyCollector {
139 deps: HashMap<String, DependencySpec>,
140}
141
142impl DependencyCollector {
143 pub fn new() -> Self {
145 Self::default()
146 }
147
148 pub fn add(&mut self, name: impl Into<String>, spec: DependencySpec) {
150 let name = name.into();
151 self.deps.entry(name).or_insert(spec);
152 }
153
154 pub fn add_simple(&mut self, name: impl Into<String>, version: impl Into<String>) {
156 self.add(name, DependencySpec::new(version));
157 }
158
159 pub fn has(&self, name: &str) -> bool {
161 self.deps.contains_key(name)
162 }
163
164 pub fn get(&self, name: &str) -> Option<&DependencySpec> {
166 self.deps.get(name)
167 }
168
169 pub fn iter(&self) -> impl Iterator<Item = (&str, &DependencySpec)> {
171 self.deps.iter().map(|(k, v)| (k.as_str(), v))
172 }
173
174 pub fn sorted(&self) -> Vec<(&str, &DependencySpec)> {
176 let mut deps: Vec<_> = self.iter().collect();
177 deps.sort_by_key(|(name, _)| *name);
178 deps
179 }
180
181 pub fn is_empty(&self) -> bool {
183 self.deps.is_empty()
184 }
185
186 pub fn len(&self) -> usize {
188 self.deps.len()
189 }
190}
191
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_import_collector_basic() {
        let mut imports = ImportCollector::new();
        imports.add("std::io", "Read");
        imports.add("std::io", "Write");
        imports.add("std::collections", "HashMap");

        assert!(imports.has_module("std::io"));
        assert!(imports.has_symbol("std::io", "Read"));
        assert!(!imports.has_symbol("std::io", "Seek"));
        assert_eq!(imports.len(), 2);
    }

    #[test]
    fn test_import_collector_merge() {
        let mut a = ImportCollector::new();
        a.add("std::io", "Read");

        let mut b = ImportCollector::new();
        b.add("std::io", "Write");
        b.add("std::fs", "File");

        a.merge(&b);

        assert!(a.has_symbol("std::io", "Read"));
        assert!(a.has_symbol("std::io", "Write"));
        assert!(a.has_module("std::fs"));
    }

    /// A module registered without symbols is present but has an empty symbol set.
    #[test]
    fn test_import_collector_add_module() {
        let mut imports = ImportCollector::new();
        assert!(imports.is_empty());

        imports.add_module("crate::prelude");

        assert!(!imports.is_empty());
        assert!(imports.has_module("crate::prelude"));
        assert!(!imports.has_symbol("crate::prelude", "Anything"));
        let (module, symbols) = imports.iter().next().unwrap();
        assert_eq!(module, "crate::prelude");
        assert!(symbols.is_empty());
    }

    /// Modules iterate in first-registration order; symbols come out sorted and deduplicated.
    #[test]
    fn test_import_collector_ordering() {
        let mut imports = ImportCollector::new();
        imports.add("zeta", "Z");
        imports.add_module("alpha");
        imports.add("zeta", "A");
        imports.add("zeta", "A");

        let modules: Vec<&str> = imports.iter().map(|(m, _)| m).collect();
        assert_eq!(modules, ["zeta", "alpha"]);

        let (_, zeta) = imports.iter().next().unwrap();
        let symbols: Vec<&str> = zeta.iter().map(String::as_str).collect();
        assert_eq!(symbols, ["A", "Z"]);
    }

    #[test]
    fn test_dependency_collector() {
        let mut deps = DependencyCollector::new();
        deps.add_simple("serde", "1.0");
        deps.add(
            "tokio",
            DependencySpec::new("1").with_features(["rt-multi-thread"]),
        );

        assert!(deps.has("serde"));
        assert!(deps.has("tokio"));
        assert!(!deps.has("async-std"));

        let tokio = deps.get("tokio").unwrap();
        assert_eq!(tokio.features, vec!["rt-multi-thread"]);
    }

    /// `sorted` orders by crate name regardless of insertion order.
    #[test]
    fn test_dependency_collector_sorted() {
        let mut deps = DependencyCollector::new();
        assert!(deps.is_empty());

        deps.add_simple("tokio", "1");
        deps.add_simple("anyhow", "1.0");
        deps.add_simple("serde", "1.0");

        let names: Vec<&str> = deps.sorted().into_iter().map(|(n, _)| n).collect();
        assert_eq!(names, ["anyhow", "serde", "tokio"]);
        assert_eq!(deps.len(), 3);
    }

    /// Registering the same crate twice keeps one entry, with the first version.
    #[test]
    fn test_dependency_collector_duplicate_keeps_first_version() {
        let mut deps = DependencyCollector::new();
        deps.add_simple("serde", "1.0");
        deps.add_simple("serde", "2.0");

        assert_eq!(deps.len(), 1);
        assert_eq!(deps.get("serde").unwrap().version, "1.0");
    }
}