1use std::collections::HashMap;
14use thiserror::Error;
15
/// Errors reported by the type-inference module.
///
/// Each variant carries a `String` payload that is interpolated into the
/// `Display` message declared on the variant.
#[derive(Debug, Error)]
pub enum TypeInferenceError {
    /// A lookup was made for a variable name that is not registered.
    /// The payload is the name that was requested.
    #[error("Variable not found: {0}")]
    VariableNotFound(String),

    /// Inference failed; the payload is free-form detail text.
    // NOTE(review): not constructed anywhere in the visible code — confirm callers.
    #[error("Type inference failed: {0}")]
    InferenceFailed(String),

    /// A type was ambiguous; the payload is free-form detail text.
    // NOTE(review): not constructed anywhere in the visible code — confirm callers.
    #[error("Ambiguous type: {0}")]
    AmbiguousType(String),
}
28
/// Shorthand for `std::result::Result` with [`TypeInferenceError`] as the error type.
pub type Result<T> = std::result::Result<T, TypeInferenceError>;
30
/// Description of a single variable's type.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TypeInfo {
    /// Textual name of the type (for example `"i32"` or `"Vec<u8>"`).
    pub type_name: String,

    /// Whether the variable was declared mutable.
    pub is_mutable: bool,

    /// Optional free-form description of where the variable was defined.
    pub definition_location: Option<String>,
}

impl TypeInfo {
    /// Builds a `TypeInfo` for `type_name` that is immutable and has no
    /// definition location.
    pub fn new(type_name: impl Into<String>) -> Self {
        let type_name = type_name.into();
        TypeInfo { type_name, is_mutable: false, definition_location: None }
    }

    /// Returns `self` with the mutability flag replaced by `is_mutable`.
    pub fn with_mutable(self, is_mutable: bool) -> Self {
        Self { is_mutable, ..self }
    }

    /// Returns `self` with the definition location set to `location`.
    pub fn with_location(self, location: impl Into<String>) -> Self {
        let definition_location = Some(location.into());
        Self { definition_location, ..self }
    }
}
62
63#[derive(Debug, Clone)]
85pub struct TypeInference {
86 variables: HashMap<String, TypeInfo>,
88
89 debug: bool,
91}
92
93impl TypeInference {
94 pub fn new() -> Self {
96 Self { variables: HashMap::new(), debug: false }
97 }
98
99 pub fn with_debug(mut self, debug: bool) -> Self {
101 self.debug = debug;
102 self
103 }
104
105 pub fn register_variable(&mut self, name: impl Into<String>, type_info: TypeInfo) {
112 let name = name.into();
113 if self.debug {
114 eprintln!("TypeInference: Registering {} : {}", name, type_info.type_name);
115 }
116 self.variables.insert(name, type_info);
117 }
118
119 pub fn get_variable_type(&self, name: impl AsRef<str>) -> Result<&TypeInfo> {
133 self.variables
134 .get(name.as_ref())
135 .ok_or_else(|| TypeInferenceError::VariableNotFound(name.as_ref().to_string()))
136 }
137
138 pub fn infer_from_code(&self, code: impl AsRef<str>) -> Vec<(String, TypeInfo)> {
153 let code = code.as_ref();
154 let mut types = Vec::new();
155
156 let parsed = match syn::parse_file(code) {
158 Ok(file) => file,
159 Err(_) => {
160 if self.debug {
161 eprintln!("Failed to parse code as complete file, trying as statements");
162 }
163 let wrapped = format!("fn __oxur_wrapper() {{\n{}\n}}", code);
165 match syn::parse_file(&wrapped) {
166 Ok(file) => file,
167 Err(_) => return types, }
169 }
170 };
171
172 self.extract_types_from_file(&parsed, &mut types);
174
175 types
176 }
177
178 fn extract_types_from_file(&self, file: &syn::File, types: &mut Vec<(String, TypeInfo)>) {
180 for item in &file.items {
181 self.extract_types_from_item(item, types);
182 }
183 }
184
185 fn extract_types_from_item(&self, item: &syn::Item, types: &mut Vec<(String, TypeInfo)>) {
187 match item {
188 syn::Item::Fn(func) => {
189 self.extract_types_from_block(&func.block, types);
191 }
192 syn::Item::Static(static_item) => {
193 let name = static_item.ident.to_string();
195 let ty = &static_item.ty;
196 let type_name = quote::quote!(#ty).to_string();
197 types.push((
198 name,
199 TypeInfo::new(type_name).with_mutable(matches!(
200 static_item.mutability,
201 syn::StaticMutability::Mut(_)
202 )),
203 ));
204 }
205 syn::Item::Const(const_item) => {
206 let name = const_item.ident.to_string();
208 let ty = &const_item.ty;
209 let type_name = quote::quote!(#ty).to_string();
210 types.push((name, TypeInfo::new(type_name)));
211 }
212 _ => {
213 }
215 }
216 }
217
218 fn extract_types_from_block(&self, block: &syn::Block, types: &mut Vec<(String, TypeInfo)>) {
220 for stmt in &block.stmts {
221 self.extract_types_from_stmt(stmt, types);
222 }
223 }
224
225 fn extract_types_from_stmt(&self, stmt: &syn::Stmt, types: &mut Vec<(String, TypeInfo)>) {
227 match stmt {
228 syn::Stmt::Local(local) => {
229 self.extract_types_from_local(local, types);
230 }
231 syn::Stmt::Expr(expr, _) => {
232 self.extract_types_from_expr(expr, types);
234 }
235 _ => {}
236 }
237 }
238
239 fn extract_types_from_local(&self, local: &syn::Local, types: &mut Vec<(String, TypeInfo)>) {
241 let (pat, explicit_type) = match &local.pat {
243 syn::Pat::Type(pat_type) => {
244 let ty = &pat_type.ty;
246 let ty_str = quote::quote!(#ty).to_string();
247 (&*pat_type.pat, Some(ty_str))
248 }
249 other_pat => (other_pat, None),
250 };
251
252 if let syn::Pat::Ident(pat_ident) = pat {
254 let name = pat_ident.ident.to_string();
255 let is_mutable = pat_ident.mutability.is_some();
256
257 let type_name = if let Some(ty) = explicit_type {
259 ty
260 } else if let Some(init) = &local.init {
261 self.infer_type_from_expr(&init.expr)
263 } else {
264 "unknown".to_string()
265 };
266
267 types.push((name, TypeInfo::new(type_name).with_mutable(is_mutable)));
268 }
269 }
270
271 fn extract_types_from_expr(&self, expr: &syn::Expr, types: &mut Vec<(String, TypeInfo)>) {
273 match expr {
274 syn::Expr::Block(block_expr) => {
275 self.extract_types_from_block(&block_expr.block, types);
276 }
277 syn::Expr::If(if_expr) => {
278 self.extract_types_from_block(&if_expr.then_branch, types);
279 if let Some((_, else_branch)) = &if_expr.else_branch {
280 self.extract_types_from_expr(else_branch, types);
281 }
282 }
283 syn::Expr::Loop(loop_expr) => {
284 self.extract_types_from_block(&loop_expr.body, types);
285 }
286 syn::Expr::While(while_expr) => {
287 self.extract_types_from_block(&while_expr.body, types);
288 }
289 syn::Expr::ForLoop(for_expr) => {
290 self.extract_types_from_block(&for_expr.body, types);
291 }
292 syn::Expr::Match(match_expr) => {
293 for arm in &match_expr.arms {
294 self.extract_types_from_expr(&arm.body, types);
295 }
296 }
297 _ => {}
298 }
299 }
300
301 fn infer_type_from_expr(&self, expr: &syn::Expr) -> String {
303 match expr {
304 syn::Expr::Lit(lit_expr) => match &lit_expr.lit {
305 syn::Lit::Str(_) => "&str".to_string(),
306 syn::Lit::ByteStr(_) => "&[u8]".to_string(),
307 syn::Lit::Byte(_) => "u8".to_string(),
308 syn::Lit::Char(_) => "char".to_string(),
309 syn::Lit::Int(int_lit) => {
310 let suffix = int_lit.suffix();
312 if suffix.is_empty() {
313 "i32".to_string() } else {
315 suffix.to_string()
316 }
317 }
318 syn::Lit::Float(float_lit) => {
319 let suffix = float_lit.suffix();
320 if suffix.is_empty() {
321 "f64".to_string() } else {
323 suffix.to_string()
324 }
325 }
326 syn::Lit::Bool(_) => "bool".to_string(),
327 _ => "unknown".to_string(),
328 },
329 syn::Expr::Array(_) => "array".to_string(),
330 syn::Expr::Tuple(_) => "tuple".to_string(),
331 syn::Expr::Call(call) => {
332 if let syn::Expr::Path(path) = &*call.func {
334 if let Some(segment) = path.path.segments.last() {
335 let fn_name = segment.ident.to_string();
336 return match fn_name.as_str() {
338 "String" => "String".to_string(),
339 "Vec" => "Vec<_>".to_string(),
340 "HashMap" => "HashMap<_, _>".to_string(),
341 "Box" => "Box<_>".to_string(),
342 _ => "unknown".to_string(),
343 };
344 }
345 }
346 "unknown".to_string()
347 }
348 syn::Expr::MethodCall(method) => {
349 match method.method.to_string().as_str() {
351 "to_string" => "String".to_string(),
352 "to_vec" => "Vec<_>".to_string(),
353 "clone" => "unknown".to_string(),
354 _ => "unknown".to_string(),
355 }
356 }
357 _ => "unknown".to_string(),
358 }
359 }
360
361 pub fn all_variables(&self) -> impl Iterator<Item = (&String, &TypeInfo)> {
363 self.variables.iter()
364 }
365
366 pub fn has_variable(&self, name: impl AsRef<str>) -> bool {
368 self.variables.contains_key(name.as_ref())
369 }
370
371 pub fn remove_variable(&mut self, name: impl AsRef<str>) -> Option<TypeInfo> {
373 self.variables.remove(name.as_ref())
374 }
375
376 pub fn clear(&mut self) {
378 self.variables.clear();
379 }
380
381 pub fn variable_count(&self) -> usize {
383 self.variables.len()
384 }
385}
386
387impl Default for TypeInference {
388 fn default() -> Self {
389 Self::new()
390 }
391}
392
/// Unit tests for `TypeInfo` and `TypeInference`.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_type_info_creation() {
        let info = TypeInfo::new("i32");
        assert_eq!(info.type_name, "i32");
        assert!(!info.is_mutable);
        assert!(info.definition_location.is_none());
    }

    #[test]
    fn test_type_info_with_mutable() {
        let info = TypeInfo::new("String").with_mutable(true);
        assert_eq!(info.type_name, "String");
        assert!(info.is_mutable);
    }

    #[test]
    fn test_type_info_with_location() {
        let info = TypeInfo::new("Vec<u8>").with_location("line 42");
        assert_eq!(info.type_name, "Vec<u8>");
        assert_eq!(info.definition_location, Some("line 42".to_string()));
    }

    #[test]
    fn test_inference_creation() {
        let inference = TypeInference::new();
        assert_eq!(inference.variable_count(), 0);
        assert!(!inference.debug);
    }

    #[test]
    fn test_inference_with_debug() {
        let inference = TypeInference::new().with_debug(true);
        assert!(inference.debug);
    }

    #[test]
    fn test_register_and_get_variable() {
        let mut inference = TypeInference::new();

        inference.register_variable("x", TypeInfo::new("i32"));

        let type_info = inference.get_variable_type("x").expect("Variable not found");
        assert_eq!(type_info.type_name, "i32");
    }

    #[test]
    fn test_get_nonexistent_variable() {
        let inference = TypeInference::new();

        // Match exhaustively so that an unexpected `Ok` or a different error
        // variant fails the test instead of slipping through.
        match inference.get_variable_type("nonexistent") {
            Err(TypeInferenceError::VariableNotFound(name)) => assert_eq!(name, "nonexistent"),
            other => panic!("expected VariableNotFound, got {:?}", other),
        }
    }

    #[test]
    fn test_register_multiple_variables() {
        let mut inference = TypeInference::new();

        inference.register_variable("x", TypeInfo::new("i32"));
        inference.register_variable("y", TypeInfo::new("String"));
        inference.register_variable("z", TypeInfo::new("bool"));

        assert_eq!(inference.variable_count(), 3);
    }

    #[test]
    fn test_has_variable() {
        let mut inference = TypeInference::new();

        inference.register_variable("x", TypeInfo::new("i32"));

        assert!(inference.has_variable("x"));
        assert!(!inference.has_variable("y"));
    }

    #[test]
    fn test_remove_variable() {
        let mut inference = TypeInference::new();

        inference.register_variable("x", TypeInfo::new("i32"));
        assert_eq!(inference.variable_count(), 1);

        let removed = inference.remove_variable("x");
        assert!(removed.is_some());
        assert_eq!(inference.variable_count(), 0);
    }

    #[test]
    fn test_clear() {
        let mut inference = TypeInference::new();

        inference.register_variable("x", TypeInfo::new("i32"));
        inference.register_variable("y", TypeInfo::new("String"));

        assert_eq!(inference.variable_count(), 2);

        inference.clear();
        assert_eq!(inference.variable_count(), 0);
    }

    #[test]
    fn test_all_variables() {
        let mut inference = TypeInference::new();

        inference.register_variable("x", TypeInfo::new("i32"));
        inference.register_variable("y", TypeInfo::new("String"));

        let all: Vec<_> = inference.all_variables().collect();
        assert_eq!(all.len(), 2);
    }

    #[test]
    fn test_infer_from_code_explicit_types() {
        let inference = TypeInference::new();

        let code = r#"
fn main() {
    let x: i32 = 42;
    let y: String = "hello".to_string();
}
"#;
        let inferred = inference.infer_from_code(code);

        assert_eq!(inferred.len(), 2);
        assert_eq!(inferred[0].0, "x");
        assert_eq!(inferred[0].1.type_name, "i32");
        assert!(!inferred[0].1.is_mutable);

        assert_eq!(inferred[1].0, "y");
        assert_eq!(inferred[1].1.type_name, "String");
        assert!(!inferred[1].1.is_mutable);
    }

    #[test]
    fn test_infer_from_code_inferred_literal() {
        let inference = TypeInference::new();

        let code = r#"
fn test() {
    let x = 42;
    let y = 3.14;
    let z = true;
    let s = "hello";
}
"#;
        let inferred = inference.infer_from_code(code);

        assert_eq!(inferred.len(), 4);
        assert_eq!(inferred[0].1.type_name, "i32");
        assert_eq!(inferred[1].1.type_name, "f64");
        assert_eq!(inferred[2].1.type_name, "bool");
        assert_eq!(inferred[3].1.type_name, "&str");
    }

    #[test]
    fn test_infer_from_code_mutable() {
        let inference = TypeInference::new();

        let code = r#"
fn test() {
    let mut x: i32 = 42;
    let y: i32 = 10;
}
"#;
        let inferred = inference.infer_from_code(code);

        assert_eq!(inferred.len(), 2);
        assert!(inferred[0].1.is_mutable);
        assert!(!inferred[1].1.is_mutable);
    }

    #[test]
    fn test_infer_from_code_typed_suffix() {
        let inference = TypeInference::new();

        let code = r#"
fn test() {
    let x = 42u64;
    let y = 3.14f32;
}
"#;
        let inferred = inference.infer_from_code(code);

        assert_eq!(inferred.len(), 2);
        assert_eq!(inferred[0].1.type_name, "u64");
        assert_eq!(inferred[1].1.type_name, "f32");
    }

    #[test]
    fn test_infer_from_code_method_calls() {
        let inference = TypeInference::new();

        let code = r#"
fn test() {
    let s = "hello".to_string();
    let v = vec![1, 2, 3].to_vec();
}
"#;
        let inferred = inference.infer_from_code(code);

        assert_eq!(inferred.len(), 2);
        assert_eq!(inferred[0].1.type_name, "String");
        assert_eq!(inferred[1].1.type_name, "Vec<_>");
    }

    #[test]
    fn test_infer_from_code_nested_blocks() {
        let inference = TypeInference::new();

        let code = r#"
fn test() {
    let x = 1;
    if true {
        let y = 2;
    }
    for i in 0..10 {
        let z = 3;
    }
}
"#;
        let inferred = inference.infer_from_code(code);

        assert_eq!(inferred.len(), 3);
        assert_eq!(inferred[0].0, "x");
        assert_eq!(inferred[1].0, "y");
        assert_eq!(inferred[2].0, "z");
    }

    #[test]
    fn test_infer_from_code_invalid_syntax() {
        let inference = TypeInference::new();

        let code = "this is not valid rust code {{{";
        let inferred = inference.infer_from_code(code);

        assert_eq!(inferred.len(), 0);
    }

    #[test]
    fn test_infer_from_code_wrapped_statements() {
        let inference = TypeInference::new();

        // Bare statements are not a valid file, so this exercises the
        // wrap-in-a-function fallback.
        let code = "let x: i32 = 42;\nlet y: bool = true;";
        let inferred = inference.infer_from_code(code);

        assert_eq!(inferred.len(), 2);
        assert_eq!(inferred[0].0, "x");
        assert_eq!(inferred[0].1.type_name, "i32");
        assert_eq!(inferred[1].0, "y");
        assert_eq!(inferred[1].1.type_name, "bool");
    }

    #[test]
    fn test_infer_from_code_constants() {
        let inference = TypeInference::new();

        let code = r#"
const MAX: usize = 100;
static mut COUNTER: i32 = 0;
"#;
        let inferred = inference.infer_from_code(code);

        assert_eq!(inferred.len(), 2);

        assert_eq!(inferred[0].0, "MAX");
        assert_eq!(inferred[0].1.type_name, "usize");
        assert!(!inferred[0].1.is_mutable);

        assert_eq!(inferred[1].0, "COUNTER");
        assert_eq!(inferred[1].1.type_name, "i32");
        assert!(inferred[1].1.is_mutable);
    }

    #[test]
    fn test_default() {
        let inference1 = TypeInference::default();
        let inference2 = TypeInference::new();

        assert_eq!(inference1.variable_count(), inference2.variable_count());
        assert_eq!(inference1.debug, inference2.debug);
    }

    #[test]
    fn test_overwrite_variable_type() {
        let mut inference = TypeInference::new();

        inference.register_variable("x", TypeInfo::new("i32"));
        assert_eq!(inference.get_variable_type("x").unwrap().type_name, "i32");

        // Registering the same name again replaces the earlier entry.
        inference.register_variable("x", TypeInfo::new("String"));
        assert_eq!(inference.get_variable_type("x").unwrap().type_name, "String");
    }
}