1use std::collections::HashMap;
8
9use crate::ast::ExprId;
10use crate::intern::InternedStr;
11use crate::type_repr::TypeRepr;
12
13#[derive(Debug, Clone)]
18pub struct TypeConstraint {
19 pub expr_id: ExprId,
21 pub ty: TypeRepr,
23 pub context: String,
25}
26
27impl TypeConstraint {
28 pub fn new(expr_id: ExprId, ty: TypeRepr, context: impl Into<String>) -> Self {
30 Self {
31 expr_id,
32 ty,
33 context: context.into(),
34 }
35 }
36
37 pub fn source_display(&self) -> &'static str {
39 self.ty.source_display()
40 }
41}
42
43#[derive(Debug, Clone)]
45pub struct ParamLink {
46 pub expr_id: ExprId,
48 pub param_name: InternedStr,
50 pub context: String,
52}
53
54#[derive(Debug, Clone, Default)]
59pub struct TypeEnv {
60 pub param_constraints: HashMap<InternedStr, Vec<TypeConstraint>>,
62
63 pub expr_constraints: HashMap<ExprId, Vec<TypeConstraint>>,
65
66 pub return_constraints: Vec<TypeConstraint>,
68
69 pub expr_to_param: Vec<ParamLink>,
71
72 pub param_to_exprs: HashMap<InternedStr, Vec<ExprId>>,
77}
78
79impl TypeEnv {
80 pub fn new() -> Self {
82 Self::default()
83 }
84
85 pub fn add_param_constraint(&mut self, param: InternedStr, constraint: TypeConstraint) {
87 self.param_constraints
88 .entry(param)
89 .or_default()
90 .push(constraint);
91 }
92
93 pub fn add_expr_constraint(&mut self, constraint: TypeConstraint) {
95 self.expr_constraints
96 .entry(constraint.expr_id)
97 .or_default()
98 .push(constraint);
99 }
100
101 pub fn add_constraint(&mut self, constraint: TypeConstraint) {
103 self.add_expr_constraint(constraint);
104 }
105
106 pub fn add_return_constraint(&mut self, constraint: TypeConstraint) {
108 self.return_constraints.push(constraint);
109 }
110
111 pub fn link_expr_to_param(&mut self, expr_id: ExprId, param_name: InternedStr, context: impl Into<String>) {
115 self.expr_to_param.push(ParamLink {
117 expr_id,
118 param_name,
119 context: context.into(),
120 });
121
122 self.param_to_exprs
124 .entry(param_name)
125 .or_default()
126 .push(expr_id);
127 }
128
129 pub fn get_param_constraints(&self, param: InternedStr) -> Option<&Vec<TypeConstraint>> {
131 self.param_constraints.get(¶m)
132 }
133
134 pub fn get_expr_constraints(&self, expr_id: ExprId) -> Option<&Vec<TypeConstraint>> {
136 self.expr_constraints.get(&expr_id)
137 }
138
139 pub fn get_linked_param(&self, expr_id: ExprId) -> Option<InternedStr> {
141 self.expr_to_param
142 .iter()
143 .find(|link| link.expr_id == expr_id)
144 .map(|link| link.param_name)
145 }
146
147 pub fn param_constraint_count(&self) -> usize {
149 self.param_constraints.values().map(|v| v.len()).sum()
150 }
151
152 pub fn expr_constraint_count(&self) -> usize {
154 self.expr_constraints.values().map(|v| v.len()).sum()
155 }
156
157 pub fn return_constraint_count(&self) -> usize {
159 self.return_constraints.len()
160 }
161
162 pub fn get_return_type(&self) -> Option<&TypeRepr> {
164 self.return_constraints.first().map(|c| &c.ty)
165 }
166
167 pub fn total_constraint_count(&self) -> usize {
169 self.param_constraint_count() + self.expr_constraint_count() + self.return_constraint_count()
170 }
171
172 pub fn is_empty(&self) -> bool {
174 self.param_constraints.is_empty()
175 && self.expr_constraints.is_empty()
176 && self.return_constraints.is_empty()
177 }
178
179 pub fn merge(&mut self, other: TypeEnv) {
181 for (param, constraints) in other.param_constraints {
182 self.param_constraints
183 .entry(param)
184 .or_default()
185 .extend(constraints);
186 }
187 for (expr_id, constraints) in other.expr_constraints {
188 self.expr_constraints
189 .entry(expr_id)
190 .or_default()
191 .extend(constraints);
192 }
193 self.return_constraints.extend(other.return_constraints);
194 self.expr_to_param.extend(other.expr_to_param);
195
196 for (param, expr_ids) in other.param_to_exprs {
198 self.param_to_exprs
199 .entry(param)
200 .or_default()
201 .extend(expr_ids);
202 }
203 }
204
205 pub fn summary(&self) -> String {
207 format!(
208 "TypeEnv {{ params: {}, exprs: {}, returns: {}, links: {} }}",
209 self.param_constraints.len(),
210 self.expr_constraints.len(),
211 self.return_constraints.len(),
212 self.expr_to_param.len(),
213 )
214 }
215}
216
#[cfg(test)]
mod tests {
    use super::*;
    use crate::intern::StringInterner;
    use crate::type_repr::{
        CTypeSource, CTypeSpecs, InferredType, IntSize, RustTypeRepr, RustTypeSource,
    };

    /// Signed C `int` with no derived parts, tagged with the header source.
    fn c_int_type() -> TypeRepr {
        let specs = CTypeSpecs::Int {
            signed: true,
            size: IntSize::Int,
        };
        TypeRepr::CType {
            specs,
            derived: Vec::new(),
            source: CTypeSource::Header,
        }
    }

    /// Rust-side `c_int` built from its type string, tagged as a function
    /// parameter source.
    fn rust_c_int_type() -> TypeRepr {
        let repr = RustTypeRepr::from_type_string("c_int");
        let source = RustTypeSource::FnParam {
            func_name: "test".to_string(),
            param_index: 0,
        };
        TypeRepr::RustType { repr, source }
    }

    /// `SV *` built through the apidoc string constructor.
    fn apidoc_sv_ptr_type() -> TypeRepr {
        let strings = StringInterner::new();
        TypeRepr::from_apidoc_string("SV *", &strings)
    }

    #[test]
    fn test_type_env_new() {
        let fresh = TypeEnv::new();
        assert!(fresh.is_empty());
        assert_eq!(fresh.total_constraint_count(), 0);
    }

    #[test]
    fn test_add_expr_constraint() {
        let mut env = TypeEnv::new();
        let id = ExprId::next();

        env.add_expr_constraint(TypeConstraint::new(id, c_int_type(), "test context"));

        assert!(!env.is_empty());
        assert_eq!(env.expr_constraint_count(), 1);
        let stored = env.get_expr_constraints(id).unwrap();
        assert_eq!(stored.len(), 1);
    }

    #[test]
    fn test_add_multiple_constraints() {
        let mut env = TypeEnv::new();
        let id = ExprId::next();

        let from_header = TypeConstraint::new(id, c_int_type(), "from C header");
        env.add_constraint(from_header);
        let from_bindings = TypeConstraint::new(id, rust_c_int_type(), "from bindings");
        env.add_constraint(from_bindings);

        let stored = env.get_expr_constraints(id).unwrap();
        assert_eq!(stored.len(), 2);
        assert_eq!(stored[0].source_display(), "c-header");
        assert_eq!(stored[1].source_display(), "rust-bindings");
    }

    #[test]
    fn test_link_expr_to_param() {
        let mut env = TypeEnv::new();
        let id = ExprId::next();

        let mut strings = StringInterner::new();
        let name = strings.intern("x");

        env.link_expr_to_param(id, name, "parameter reference");

        let linked = env.get_linked_param(id);
        assert_eq!(linked, Some(name));
    }

    #[test]
    fn test_merge() {
        let mut target = TypeEnv::new();
        let mut incoming = TypeEnv::new();

        let first_id = ExprId::next();
        let second_id = ExprId::next();

        target.add_constraint(TypeConstraint::new(first_id, c_int_type(), "env1"));

        let char_type = TypeRepr::CType {
            specs: CTypeSpecs::Char { signed: None },
            derived: Vec::new(),
            source: CTypeSource::Apidoc {
                raw: "char".to_string(),
            },
        };
        incoming.add_constraint(TypeConstraint::new(second_id, char_type, "env2"));

        target.merge(incoming);

        assert_eq!(target.expr_constraint_count(), 2);
        assert!(target.get_expr_constraints(first_id).is_some());
        assert!(target.get_expr_constraints(second_id).is_some());
    }

    #[test]
    fn test_return_constraints() {
        let mut env = TypeEnv::new();
        let id = ExprId::next();

        let returned = TypeConstraint::new(id, apidoc_sv_ptr_type(), "return type from apidoc");
        env.add_return_constraint(returned);

        assert_eq!(env.return_constraint_count(), 1);
        assert_eq!(env.return_constraints[0].source_display(), "apidoc");
    }

    #[test]
    fn test_type_repr_source_display() {
        let header = c_int_type();
        assert_eq!(header.source_display(), "c-header");
        let apidoc = apidoc_sv_ptr_type();
        assert_eq!(apidoc.source_display(), "apidoc");

        let bindings = rust_c_int_type();
        assert_eq!(bindings.source_display(), "rust-bindings");

        let literal = TypeRepr::Inferred(InferredType::IntLiteral);
        assert_eq!(literal.source_display(), "inferred");
    }
}