// colgrep 1.3.1
//
// Semantic code search powered by ColBERT
// Documentation
//! Tests for Python code extraction.

use super::common::*;
use crate::embed::build_embedding_text;
use crate::parser::Language;

/// A small typed function with a one-line docstring produces the complete
/// embedding text: header fields, parameters, return type, file, and code.
#[test]
fn test_basic_function() {
    // What the embedding text for `greet` must look like, verbatim.
    let expected = r#"Function: greet
Signature: def greet(name: str) -> str:
Description: """Say hello to someone.
Parameters: name
Returns: str
File: test test.py
Code:
def greet(name: str) -> str:
    """Say hello to someone."""
    return f"Hello, {name}!""#;

    let python_src = r#"def greet(name: str) -> str:
    """Say hello to someone."""
    return f"Hello, {name}!""#;
    let parsed = parse(python_src, Language::Python, "test.py");
    let text = build_embedding_text(get_unit_by_name(&parsed, "greet").unwrap());

    assert_eq!(text, expected);
}

/// A function that calls into an imported module reports the callee under
/// `Calls:` and the module under `Uses:`.
#[test]
fn test_function_with_imports() {
    let expected = r#"Function: fetch_data
Signature: def fetch_data(url: str) -> dict:
Description: """Fetch JSON data from URL.
Parameters: url
Returns: dict
Calls: loads
Uses: json
File: test test.py
Code:
def fetch_data(url: str) -> dict:
    """Fetch JSON data from URL."""
    return json.loads("{}")"#;

    // Two imports precede the function; only `json` is referenced in its body.
    let python_src = r#"import json
from urllib.parse import urlencode

def fetch_data(url: str) -> dict:
    """Fetch JSON data from URL."""
    return json.loads("{}")"#;
    let parsed = parse(python_src, Language::Python, "test.py");
    let fetch_data = get_unit_by_name(&parsed, "fetch_data").unwrap();
    let text = build_embedding_text(fetch_data);

    assert_eq!(text, expected);
}

/// A class is emitted as one unit whose code spans all of its methods, and
/// each method is additionally available as a unit of its own.
#[test]
fn test_class_definition() {
    let python_src = r#"class Calculator:
    """A simple calculator class."""

    def __init__(self, value: int = 0):
        self.value = value

    def add(self, x: int) -> int:
        """Add x to the current value."""
        self.value += x
        return self.value"#;
    let expected_class = r#"Class: Calculator
Signature: class Calculator:
Description: """A simple calculator class.
Variables: self.value
File: test test.py
Code:
class Calculator:
    """A simple calculator class."""

    def __init__(self, value: int = 0):
        self.value = value

    def add(self, x: int) -> int:
        """Add x to the current value."""
        self.value += x
        return self.value"#;

    let units = parse(python_src, Language::Python, "test.py");

    // The class unit is a single chunk carrying the whole class body.
    let calculator = get_unit_by_name(&units, "Calculator").unwrap();
    let class_text = build_embedding_text(calculator);
    assert_eq!(class_text, expected_class);

    // Every method must also exist as its own unit next to the class.
    for method in ["__init__", "add"] {
        assert!(
            get_unit_by_name(&units, method).is_some(),
            "Methods are extracted as separate units alongside their parent classes"
        );
    }

    // The class unit's code still mentions each method.
    for method in ["__init__", "add"] {
        assert!(calculator.code.contains(method));
    }
}

/// Decorators are kept in the `Code:` section of a function unit while the
/// `Signature:` line shows only the `def` line.
#[test]
fn test_decorated_function() {
    let expected = r#"Function: decorated_func
Signature: def decorated_func():
Description: """A decorated function.
File: test test.py
Code:
@staticmethod
@decorator_with_args(arg=1)
def decorated_func():
    """A decorated function."""
    pass"#;

    // One bare decorator and one decorator taking arguments.
    let python_src = r#"@staticmethod
@decorator_with_args(arg=1)
def decorated_func():
    """A decorated function."""
    pass"#;
    let parsed = parse(python_src, Language::Python, "test.py");
    let decorated = get_unit_by_name(&parsed, "decorated_func").unwrap();
    let text = build_embedding_text(decorated);

    assert_eq!(text, expected);
}

/// An `async def` is extracted like a regular function, with the `async`
/// keyword preserved in the signature line.
#[test]
fn test_async_function() {
    let expected = r#"Function: fetch_async
Signature: async def fetch_async(url: str) -> bytes:
Description: """Fetch data asynchronously.
Parameters: url
Returns: bytes
File: test test.py
Code:
async def fetch_async(url: str) -> bytes:
    """Fetch data asynchronously."""
    return b"data""#;

    let python_src = r#"async def fetch_async(url: str) -> bytes:
    """Fetch data asynchronously."""
    return b"data""#;
    let parsed = parse(python_src, Language::Python, "test.py");
    let text = build_embedding_text(get_unit_by_name(&parsed, "fetch_async").unwrap());

    assert_eq!(text, expected);
}

/// Variadic parameters are listed by bare name (`args, kwargs`), without the
/// `*` / `**` prefixes that appear in the signature.
#[test]
fn test_function_with_args_kwargs() {
    let expected = r#"Function: variadic_func
Signature: def variadic_func(*args, **kwargs):
Description: """Function with variadic arguments.
Parameters: args, kwargs
File: test test.py
Code:
def variadic_func(*args, **kwargs):
    """Function with variadic arguments."""
    return args, kwargs"#;

    let python_src = r#"def variadic_func(*args, **kwargs):
    """Function with variadic arguments."""
    return args, kwargs"#;
    let parsed = parse(python_src, Language::Python, "test.py");
    let variadic = get_unit_by_name(&parsed, "variadic_func").unwrap();
    let text = build_embedding_text(variadic);

    assert_eq!(text, expected);
}

/// A multi-line docstring is carried into `Description:` with its internal
/// layout intact: opening quotes kept, closing quotes dropped.
#[test]
fn test_multiline_docstring() {
    let expected = r##"Function: complex_function
Signature: def complex_function(x: int, y: int) -> int:
Description: """
    This is a complex function that does many things.

    It processes x and y in a special way.

    Args:
        x: First number
        y: Second number

    Returns:
        The processed result
Parameters: x, y
Returns: int
File: test test.py
Code:
def complex_function(x: int, y: int) -> int:
    """
    This is a complex function that does many things.

    It processes x and y in a special way.

    Args:
        x: First number
        y: Second number

    Returns:
        The processed result
    """
    return x + y"##;

    let python_src = r#"def complex_function(x: int, y: int) -> int:
    """
    This is a complex function that does many things.

    It processes x and y in a special way.

    Args:
        x: First number
        y: Second number

    Returns:
        The processed result
    """
    return x + y"#;
    let parsed = parse(python_src, Language::Python, "test.py");
    let complex_function = get_unit_by_name(&parsed, "complex_function").unwrap();
    let text = build_embedding_text(complex_function);

    assert_eq!(text, expected);
}

/// Upper-case module-level assignments are extracted as units whose embedding
/// text is just the assignment; a lower-case variable is not extracted.
#[test]
fn test_constants() {
    let python_src = r#"MAX_SIZE = 1024
DEFAULT_NAME = "test"
regular_var = "not a constant"

def process():
    pass"#;
    let parsed = parse(python_src, Language::Python, "test.py");

    // Each constant name paired with the exact embedding text it must yield.
    let constant_cases = [
        ("MAX_SIZE", r#"MAX_SIZE = 1024"#),
        ("DEFAULT_NAME", r#"DEFAULT_NAME = "test""#),
    ];
    for (constant_name, expected_text) in constant_cases {
        let constant = get_unit_by_name(&parsed, constant_name).unwrap();
        let constant_text = build_embedding_text(constant);
        assert_eq!(constant_text, expected_text);
    }

    // A lower-case name is not treated as a constant, so no unit exists for it.
    assert!(get_unit_by_name(&parsed, "regular_var").is_none());
}

/// The unit for an outer class includes the source of any class nested
/// inside it.
#[test]
fn test_nested_class() {
    let expected_outer = r#"Class: Outer
Signature: class Outer:
Description: """Outer class.
File: test test.py
Code:
class Outer:
    """Outer class."""

    class Inner:
        """Inner class."""

        def inner_method(self):
            pass"#;

    let python_src = r#"class Outer:
    """Outer class."""

    class Inner:
        """Inner class."""

        def inner_method(self):
            pass"#;
    let parsed = parse(python_src, Language::Python, "test.py");
    let outer_text = build_embedding_text(get_unit_by_name(&parsed, "Outer").unwrap());

    assert_eq!(outer_text, expected_outer);
}

/// Subclasses carry an `Extends:` line naming their base class; a class with
/// no base has no such line.
#[test]
fn test_class_inheritance() {
    let expected_dog = r#"Class: Dog
Signature: class Dog(Animal):
Extends: Animal
Description: """A dog that barks.
File: test test.py
Code:
class Dog(Animal):
    """A dog that barks."""
    def speak(self):
        return "Woof!""#;
    let expected_cat = r#"Class: Cat
Signature: class Cat(Animal):
Extends: Animal
Description: """A cat that meows.
File: test test.py
Code:
class Cat(Animal):
    """A cat that meows."""
    def speak(self):
        return "Meow!""#;

    // One base class followed by two subclasses; the source ends in a newline.
    let python_src = r#"class Animal:
    """Base animal class."""
    def speak(self):
        pass

class Dog(Animal):
    """A dog that barks."""
    def speak(self):
        return "Woof!"

class Cat(Animal):
    """A cat that meows."""
    def speak(self):
        return "Meow!"
"#;
    let parsed = parse(python_src, Language::Python, "test.py");

    // The base class declares no parent, so nothing is reported as extended.
    let base_unit = get_unit_by_name(&parsed, "Animal").unwrap();
    let animal_text = build_embedding_text(base_unit);
    assert!(!animal_text.contains("Extends:"));

    let dog_unit = get_unit_by_name(&parsed, "Dog").unwrap();
    let dog_text = build_embedding_text(dog_unit);
    assert_eq!(dog_text, expected_dog);

    let cat_unit = get_unit_by_name(&parsed, "Cat").unwrap();
    let cat_text = build_embedding_text(cat_unit);
    assert_eq!(cat_text, expected_cat);
}

/// A lambda bound to a name yields no `lambda` / `<lambda>` unit, while a
/// real function that calls it records the call.
#[test]
fn test_lambda_not_extracted_as_function() {
    let expected = r#"Function: real_function
Signature: def real_function():
Calls: square
File: test test.py
Code:
def real_function():
    return square(5)"#;

    let python_src = r#"square = lambda x: x ** 2

def real_function():
    return square(5)"#;
    let parsed = parse(python_src, Language::Python, "test.py");

    let real_function = get_unit_by_name(&parsed, "real_function").unwrap();
    let text = build_embedding_text(real_function);
    assert_eq!(text, expected);

    // Neither spelling of an anonymous-function name may appear as a unit.
    for lambda_name in ["lambda", "<lambda>"] {
        assert!(get_unit_by_name(&parsed, lambda_name).is_none());
    }
}