oxur_repl/
type_inference.rs

1//! Type inference for REPL variables
2//!
3//! Tracks types of REPL variables across evaluations.
4//!
5//! # Phase 3 Implementation
6//!
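//! Variables can be registered explicitly, and types can be extracted
//! syntactically (via `syn`) from `let` bindings, `const`s and `static`s;
//! no real type checking is performed yet. Full implementation (Phase 4+)
//! will integrate with rust-analyzer for proper type inference.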
10//!
11//! Based on ODD-0026 Section 2.1 and ODD-0038 Decision 6
12
13use std::collections::HashMap;
14use thiserror::Error;
15
16/// Type inference errors
17#[derive(Debug, Error)]
18pub enum TypeInferenceError {
19    #[error("Variable not found: {0}")]
20    VariableNotFound(String),
21
22    #[error("Type inference failed: {0}")]
23    InferenceFailed(String),
24
25    #[error("Ambiguous type: {0}")]
26    AmbiguousType(String),
27}
28
29pub type Result<T> = std::result::Result<T, TypeInferenceError>;
30
/// Type information recorded for a single variable.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TypeInfo {
    /// The variable's Rust type rendered as text (for example `"i32"`,
    /// `"String"`, or `"Vec<u8>"`).
    pub type_name: String,

    /// `true` when the variable was declared mutable.
    pub is_mutable: bool,

    /// Free-form description of where the variable was defined, if known.
    pub definition_location: Option<String>,
}

impl TypeInfo {
    /// Build type info for `type_name`: immutable, with no definition location.
    pub fn new(type_name: impl Into<String>) -> Self {
        let type_name: String = type_name.into();
        TypeInfo { type_name, is_mutable: false, definition_location: None }
    }

    /// Return `self` with its mutability flag replaced by `is_mutable`.
    pub fn with_mutable(self, is_mutable: bool) -> Self {
        TypeInfo { is_mutable, ..self }
    }

    /// Return `self` with its definition location set to `location`.
    pub fn with_location(self, location: impl Into<String>) -> Self {
        TypeInfo { definition_location: Some(location.into()), ..self }
    }
}
62
63/// Type inference engine
64///
65/// # Phase 3 Stub Implementation
66///
67/// Currently provides basic type tracking without actual inference.
68/// Variables must be explicitly registered with their types.
69///
70/// # Examples
71///
72/// ```
73/// use oxur_repl::type_inference::{TypeInference, TypeInfo};
74///
75/// let mut inference = TypeInference::new();
76///
77/// // Register a variable type
78/// inference.register_variable("x", TypeInfo::new("i32"));
79///
80/// // Look up variable type
81/// let type_info = inference.get_variable_type("x").unwrap();
82/// assert_eq!(type_info.type_name, "i32");
83/// ```
84#[derive(Debug, Clone)]
85pub struct TypeInference {
86    /// Map from variable name to type information
87    variables: HashMap<String, TypeInfo>,
88
89    /// Whether to enable debug logging
90    debug: bool,
91}
92
93impl TypeInference {
94    /// Create a new type inference engine
95    pub fn new() -> Self {
96        Self { variables: HashMap::new(), debug: false }
97    }
98
99    /// Enable debug mode
100    pub fn with_debug(mut self, debug: bool) -> Self {
101        self.debug = debug;
102        self
103    }
104
105    /// Register a variable with its type
106    ///
107    /// # Arguments
108    ///
109    /// * `name` - Variable name
110    /// * `type_info` - Type information
111    pub fn register_variable(&mut self, name: impl Into<String>, type_info: TypeInfo) {
112        let name = name.into();
113        if self.debug {
114            eprintln!("TypeInference: Registering {} : {}", name, type_info.type_name);
115        }
116        self.variables.insert(name, type_info);
117    }
118
119    /// Get the type of a variable
120    ///
121    /// # Arguments
122    ///
123    /// * `name` - Variable name
124    ///
125    /// # Returns
126    ///
127    /// Type information for the variable
128    ///
129    /// # Errors
130    ///
131    /// Returns error if variable is not found
132    pub fn get_variable_type(&self, name: impl AsRef<str>) -> Result<&TypeInfo> {
133        self.variables
134            .get(name.as_ref())
135            .ok_or_else(|| TypeInferenceError::VariableNotFound(name.as_ref().to_string()))
136    }
137
138    /// Infer types from Rust code
139    ///
140    /// # Phase 3 Stub
141    ///
142    /// Currently returns empty list. Full implementation will use rust-analyzer
143    /// to extract type information from let bindings.
144    ///
145    /// # Arguments
146    ///
147    /// * `code` - Rust source code to analyze
148    ///
149    /// # Returns
150    ///
151    /// List of (variable_name, type_info) pairs
152    pub fn infer_from_code(&self, code: impl AsRef<str>) -> Vec<(String, TypeInfo)> {
153        let code = code.as_ref();
154        let mut types = Vec::new();
155
156        // Parse the code as a Rust file
157        let parsed = match syn::parse_file(code) {
158            Ok(file) => file,
159            Err(_) => {
160                if self.debug {
161                    eprintln!("Failed to parse code as complete file, trying as statements");
162                }
163                // Try parsing as individual statements wrapped in a function
164                let wrapped = format!("fn __oxur_wrapper() {{\n{}\n}}", code);
165                match syn::parse_file(&wrapped) {
166                    Ok(file) => file,
167                    Err(_) => return types, // Can't parse, return empty
168                }
169            }
170        };
171
172        // Extract type information from the parsed file
173        self.extract_types_from_file(&parsed, &mut types);
174
175        types
176    }
177
178    /// Extract type information from a syn::File
179    fn extract_types_from_file(&self, file: &syn::File, types: &mut Vec<(String, TypeInfo)>) {
180        for item in &file.items {
181            self.extract_types_from_item(item, types);
182        }
183    }
184
185    /// Extract type information from a syn::Item
186    fn extract_types_from_item(&self, item: &syn::Item, types: &mut Vec<(String, TypeInfo)>) {
187        match item {
188            syn::Item::Fn(func) => {
189                // Extract from function body
190                self.extract_types_from_block(&func.block, types);
191            }
192            syn::Item::Static(static_item) => {
193                // Static variables
194                let name = static_item.ident.to_string();
195                let ty = &static_item.ty;
196                let type_name = quote::quote!(#ty).to_string();
197                types.push((
198                    name,
199                    TypeInfo::new(type_name).with_mutable(matches!(
200                        static_item.mutability,
201                        syn::StaticMutability::Mut(_)
202                    )),
203                ));
204            }
205            syn::Item::Const(const_item) => {
206                // Constants
207                let name = const_item.ident.to_string();
208                let ty = &const_item.ty;
209                let type_name = quote::quote!(#ty).to_string();
210                types.push((name, TypeInfo::new(type_name)));
211            }
212            _ => {
213                // Other items don't contain local variables
214            }
215        }
216    }
217
218    /// Extract type information from a block
219    fn extract_types_from_block(&self, block: &syn::Block, types: &mut Vec<(String, TypeInfo)>) {
220        for stmt in &block.stmts {
221            self.extract_types_from_stmt(stmt, types);
222        }
223    }
224
225    /// Extract type information from a statement
226    fn extract_types_from_stmt(&self, stmt: &syn::Stmt, types: &mut Vec<(String, TypeInfo)>) {
227        match stmt {
228            syn::Stmt::Local(local) => {
229                self.extract_types_from_local(local, types);
230            }
231            syn::Stmt::Expr(expr, _) => {
232                // Check for nested blocks in expressions (if, loop, etc.)
233                self.extract_types_from_expr(expr, types);
234            }
235            _ => {}
236        }
237    }
238
239    /// Extract type information from a local variable declaration
240    fn extract_types_from_local(&self, local: &syn::Local, types: &mut Vec<(String, TypeInfo)>) {
241        // Check if pattern has explicit type annotation (let x: T = ...)
242        let (pat, explicit_type) = match &local.pat {
243            syn::Pat::Type(pat_type) => {
244                // Pattern with type annotation
245                let ty = &pat_type.ty;
246                let ty_str = quote::quote!(#ty).to_string();
247                (&*pat_type.pat, Some(ty_str))
248            }
249            other_pat => (other_pat, None),
250        };
251
252        // Extract variable name from pattern
253        if let syn::Pat::Ident(pat_ident) = pat {
254            let name = pat_ident.ident.to_string();
255            let is_mutable = pat_ident.mutability.is_some();
256
257            // Determine type: explicit annotation > inferred from init > unknown
258            let type_name = if let Some(ty) = explicit_type {
259                ty
260            } else if let Some(init) = &local.init {
261                // Try to infer from initialization expression
262                self.infer_type_from_expr(&init.expr)
263            } else {
264                "unknown".to_string()
265            };
266
267            types.push((name, TypeInfo::new(type_name).with_mutable(is_mutable)));
268        }
269    }
270
271    /// Extract types from expressions (for nested blocks)
272    fn extract_types_from_expr(&self, expr: &syn::Expr, types: &mut Vec<(String, TypeInfo)>) {
273        match expr {
274            syn::Expr::Block(block_expr) => {
275                self.extract_types_from_block(&block_expr.block, types);
276            }
277            syn::Expr::If(if_expr) => {
278                self.extract_types_from_block(&if_expr.then_branch, types);
279                if let Some((_, else_branch)) = &if_expr.else_branch {
280                    self.extract_types_from_expr(else_branch, types);
281                }
282            }
283            syn::Expr::Loop(loop_expr) => {
284                self.extract_types_from_block(&loop_expr.body, types);
285            }
286            syn::Expr::While(while_expr) => {
287                self.extract_types_from_block(&while_expr.body, types);
288            }
289            syn::Expr::ForLoop(for_expr) => {
290                self.extract_types_from_block(&for_expr.body, types);
291            }
292            syn::Expr::Match(match_expr) => {
293                for arm in &match_expr.arms {
294                    self.extract_types_from_expr(&arm.body, types);
295                }
296            }
297            _ => {}
298        }
299    }
300
301    /// Attempt to infer type from an expression
302    fn infer_type_from_expr(&self, expr: &syn::Expr) -> String {
303        match expr {
304            syn::Expr::Lit(lit_expr) => match &lit_expr.lit {
305                syn::Lit::Str(_) => "&str".to_string(),
306                syn::Lit::ByteStr(_) => "&[u8]".to_string(),
307                syn::Lit::Byte(_) => "u8".to_string(),
308                syn::Lit::Char(_) => "char".to_string(),
309                syn::Lit::Int(int_lit) => {
310                    // Check suffix
311                    let suffix = int_lit.suffix();
312                    if suffix.is_empty() {
313                        "i32".to_string() // Default integer type
314                    } else {
315                        suffix.to_string()
316                    }
317                }
318                syn::Lit::Float(float_lit) => {
319                    let suffix = float_lit.suffix();
320                    if suffix.is_empty() {
321                        "f64".to_string() // Default float type
322                    } else {
323                        suffix.to_string()
324                    }
325                }
326                syn::Lit::Bool(_) => "bool".to_string(),
327                _ => "unknown".to_string(),
328            },
329            syn::Expr::Array(_) => "array".to_string(),
330            syn::Expr::Tuple(_) => "tuple".to_string(),
331            syn::Expr::Call(call) => {
332                // Try to get type from function name
333                if let syn::Expr::Path(path) = &*call.func {
334                    if let Some(segment) = path.path.segments.last() {
335                        let fn_name = segment.ident.to_string();
336                        // Common constructors
337                        return match fn_name.as_str() {
338                            "String" => "String".to_string(),
339                            "Vec" => "Vec<_>".to_string(),
340                            "HashMap" => "HashMap<_, _>".to_string(),
341                            "Box" => "Box<_>".to_string(),
342                            _ => "unknown".to_string(),
343                        };
344                    }
345                }
346                "unknown".to_string()
347            }
348            syn::Expr::MethodCall(method) => {
349                // Common method patterns
350                match method.method.to_string().as_str() {
351                    "to_string" => "String".to_string(),
352                    "to_vec" => "Vec<_>".to_string(),
353                    "clone" => "unknown".to_string(),
354                    _ => "unknown".to_string(),
355                }
356            }
357            _ => "unknown".to_string(),
358        }
359    }
360
361    /// Get all tracked variables
362    pub fn all_variables(&self) -> impl Iterator<Item = (&String, &TypeInfo)> {
363        self.variables.iter()
364    }
365
366    /// Check if a variable is tracked
367    pub fn has_variable(&self, name: impl AsRef<str>) -> bool {
368        self.variables.contains_key(name.as_ref())
369    }
370
371    /// Remove a variable from tracking
372    pub fn remove_variable(&mut self, name: impl AsRef<str>) -> Option<TypeInfo> {
373        self.variables.remove(name.as_ref())
374    }
375
376    /// Clear all tracked variables
377    pub fn clear(&mut self) {
378        self.variables.clear();
379    }
380
381    /// Get count of tracked variables
382    pub fn variable_count(&self) -> usize {
383        self.variables.len()
384    }
385}
386
387impl Default for TypeInference {
388    fn default() -> Self {
389        Self::new()
390    }
391}
392
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_type_info_creation() {
        let info = TypeInfo::new("i32");
        assert_eq!(info.type_name, "i32");
        assert!(!info.is_mutable);
        assert!(info.definition_location.is_none());
    }

    #[test]
    fn test_type_info_with_mutable() {
        let info = TypeInfo::new("String").with_mutable(true);
        assert_eq!(info.type_name, "String");
        assert!(info.is_mutable);
    }

    #[test]
    fn test_type_info_with_location() {
        let info = TypeInfo::new("Vec<u8>").with_location("line 42");
        assert_eq!(info.type_name, "Vec<u8>");
        assert_eq!(info.definition_location, Some("line 42".to_string()));
    }

    #[test]
    fn test_error_display() {
        assert_eq!(
            TypeInferenceError::VariableNotFound("x".to_string()).to_string(),
            "Variable not found: x"
        );
        assert_eq!(
            TypeInferenceError::InferenceFailed("bad".to_string()).to_string(),
            "Type inference failed: bad"
        );
        assert_eq!(
            TypeInferenceError::AmbiguousType("T".to_string()).to_string(),
            "Ambiguous type: T"
        );
    }

    #[test]
    fn test_inference_creation() {
        let inference = TypeInference::new();
        assert_eq!(inference.variable_count(), 0);
        assert!(!inference.debug);
    }

    #[test]
    fn test_inference_with_debug() {
        let inference = TypeInference::new().with_debug(true);
        assert!(inference.debug);
    }

    #[test]
    fn test_register_and_get_variable() {
        let mut inference = TypeInference::new();

        inference.register_variable("x", TypeInfo::new("i32"));

        let type_info = inference.get_variable_type("x").expect("Variable not found");
        assert_eq!(type_info.type_name, "i32");
    }

    #[test]
    fn test_get_nonexistent_variable() {
        let inference = TypeInference::new();

        // Match explicitly so a different variant (or an Ok) fails the test
        // instead of passing vacuously.
        match inference.get_variable_type("nonexistent") {
            Err(TypeInferenceError::VariableNotFound(name)) => assert_eq!(name, "nonexistent"),
            other => panic!("expected VariableNotFound, got {:?}", other),
        }
    }

    #[test]
    fn test_register_multiple_variables() {
        let mut inference = TypeInference::new();

        inference.register_variable("x", TypeInfo::new("i32"));
        inference.register_variable("y", TypeInfo::new("String"));
        inference.register_variable("z", TypeInfo::new("bool"));

        assert_eq!(inference.variable_count(), 3);
    }

    #[test]
    fn test_has_variable() {
        let mut inference = TypeInference::new();

        inference.register_variable("x", TypeInfo::new("i32"));

        assert!(inference.has_variable("x"));
        assert!(!inference.has_variable("y"));
    }

    #[test]
    fn test_remove_variable() {
        let mut inference = TypeInference::new();

        inference.register_variable("x", TypeInfo::new("i32"));
        assert_eq!(inference.variable_count(), 1);

        let removed = inference.remove_variable("x").expect("x was registered");
        assert_eq!(removed.type_name, "i32");
        assert_eq!(inference.variable_count(), 0);
        assert!(!inference.has_variable("x"));

        // Removing again is a no-op.
        assert!(inference.remove_variable("x").is_none());
    }

    #[test]
    fn test_clear() {
        let mut inference = TypeInference::new();

        inference.register_variable("x", TypeInfo::new("i32"));
        inference.register_variable("y", TypeInfo::new("String"));

        assert_eq!(inference.variable_count(), 2);

        inference.clear();
        assert_eq!(inference.variable_count(), 0);
    }

    #[test]
    fn test_all_variables() {
        let mut inference = TypeInference::new();

        inference.register_variable("x", TypeInfo::new("i32"));
        inference.register_variable("y", TypeInfo::new("String"));

        // Iteration order is unspecified, so compare sorted names.
        let mut names: Vec<&str> = inference.all_variables().map(|(name, _)| name.as_str()).collect();
        names.sort_unstable();
        assert_eq!(names, ["x", "y"]);
    }

    #[test]
    fn test_infer_from_code_explicit_types() {
        let inference = TypeInference::new();

        let code = r#"
fn main() {
    let x: i32 = 42;
    let y: String = "hello".to_string();
}
"#;
        let inferred = inference.infer_from_code(code);

        assert_eq!(inferred.len(), 2);
        assert_eq!(inferred[0].0, "x");
        assert_eq!(inferred[0].1.type_name, "i32");
        assert!(!inferred[0].1.is_mutable);

        assert_eq!(inferred[1].0, "y");
        assert_eq!(inferred[1].1.type_name, "String");
        assert!(!inferred[1].1.is_mutable);
    }

    #[test]
    fn test_infer_from_code_inferred_literal() {
        let inference = TypeInference::new();

        let code = r#"
fn test() {
    let x = 42;
    let y = 3.14;
    let z = true;
    let s = "hello";
}
"#;
        let inferred = inference.infer_from_code(code);

        assert_eq!(inferred.len(), 4);
        assert_eq!(inferred[0].1.type_name, "i32"); // Default int
        assert_eq!(inferred[1].1.type_name, "f64"); // Default float
        assert_eq!(inferred[2].1.type_name, "bool");
        assert_eq!(inferred[3].1.type_name, "&str");
    }

    #[test]
    fn test_infer_from_code_mutable() {
        let inference = TypeInference::new();

        let code = r#"
fn test() {
    let mut x: i32 = 42;
    let y: i32 = 10;
}
"#;
        let inferred = inference.infer_from_code(code);

        assert_eq!(inferred.len(), 2);
        assert!(inferred[0].1.is_mutable);
        assert!(!inferred[1].1.is_mutable);
    }

    #[test]
    fn test_infer_from_code_typed_suffix() {
        let inference = TypeInference::new();

        let code = r#"
fn test() {
    let x = 42u64;
    let y = 3.14f32;
}
"#;
        let inferred = inference.infer_from_code(code);

        assert_eq!(inferred.len(), 2);
        assert_eq!(inferred[0].1.type_name, "u64");
        assert_eq!(inferred[1].1.type_name, "f32");
    }

    #[test]
    fn test_infer_from_code_method_calls() {
        let inference = TypeInference::new();

        let code = r#"
fn test() {
    let s = "hello".to_string();
    let v = vec![1, 2, 3].to_vec();
}
"#;
        let inferred = inference.infer_from_code(code);

        assert_eq!(inferred.len(), 2);
        assert_eq!(inferred[0].1.type_name, "String");
        assert_eq!(inferred[1].1.type_name, "Vec<_>");
    }

    #[test]
    fn test_infer_from_code_nested_blocks() {
        let inference = TypeInference::new();

        let code = r#"
fn test() {
    let x = 1;
    if true {
        let y = 2;
    }
    for i in 0..10 {
        let z = 3;
    }
}
"#;
        let inferred = inference.infer_from_code(code);

        assert_eq!(inferred.len(), 3);
        assert_eq!(inferred[0].0, "x");
        assert_eq!(inferred[1].0, "y");
        assert_eq!(inferred[2].0, "z");
    }

    #[test]
    fn test_infer_from_code_uninitialized_let() {
        let inference = TypeInference::new();

        let inferred = inference.infer_from_code("fn test() { let x; }");

        assert_eq!(inferred.len(), 1);
        assert_eq!(inferred[0].0, "x");
        assert_eq!(inferred[0].1.type_name, "unknown");
    }

    #[test]
    fn test_infer_from_code_invalid_syntax() {
        let inference = TypeInference::new();

        let code = "this is not valid rust code {{{";
        let inferred = inference.infer_from_code(code);

        // Should return empty on parse failure
        assert_eq!(inferred.len(), 0);
    }

    #[test]
    fn test_infer_from_code_wrapped_statements() {
        let inference = TypeInference::new();

        // Code without function wrapper (should still parse)
        let code = "let x: i32 = 42;\nlet y: bool = true;";
        let inferred = inference.infer_from_code(code);

        // Parsed by wrapping in a function; both bindings are found in order.
        assert_eq!(inferred.len(), 2);
        assert_eq!(inferred[0].0, "x");
        assert_eq!(inferred[0].1.type_name, "i32");
        assert_eq!(inferred[1].0, "y");
        assert_eq!(inferred[1].1.type_name, "bool");
    }

    #[test]
    fn test_infer_from_code_constants() {
        let inference = TypeInference::new();

        let code = r#"
const MAX: usize = 100;
static mut COUNTER: i32 = 0;
"#;
        let inferred = inference.infer_from_code(code);

        assert_eq!(inferred.len(), 2);
        assert_eq!(inferred[0].0, "MAX");
        assert_eq!(inferred[0].1.type_name, "usize");
        assert!(!inferred[0].1.is_mutable);

        assert_eq!(inferred[1].0, "COUNTER");
        assert_eq!(inferred[1].1.type_name, "i32");
        assert!(inferred[1].1.is_mutable);
    }

    #[test]
    fn test_default() {
        let inference1 = TypeInference::default();
        let inference2 = TypeInference::new();

        assert_eq!(inference1.variable_count(), inference2.variable_count());
        assert_eq!(inference1.debug, inference2.debug);
    }

    #[test]
    fn test_overwrite_variable_type() {
        let mut inference = TypeInference::new();

        inference.register_variable("x", TypeInfo::new("i32"));
        assert_eq!(inference.get_variable_type("x").unwrap().type_name, "i32");

        // Overwrite with different type
        inference.register_variable("x", TypeInfo::new("String"));
        assert_eq!(inference.get_variable_type("x").unwrap().type_name, "String");
    }
}