use std::path::Path;
use crate::parser::{CodeUnit, UnitType};
const MAX_EMBEDDING_TEXT_CHARS: usize = 8 * 1024;
const TRUNCATION_MARKER: &str = "\n[...truncated...]\n";
fn shorten_path(path: &Path) -> String {
let components: Vec<_> = path.components().collect();
let len = components.len();
let start = len.saturating_sub(4);
let shortened: std::path::PathBuf = components[start..].iter().collect();
shortened.display().to_string()
}
fn normalize_path_for_embedding(path_str: &str) -> String {
let original_filename = path_str.rsplit(['/', '\\']).next().unwrap_or(path_str);
let path_without_ext = if let Some(dot_pos) = path_str.rfind('.') {
&path_str[..dot_pos]
} else {
path_str
};
let mut result = String::with_capacity(path_without_ext.len() * 2);
let chars: Vec<char> = path_without_ext.chars().collect();
for (i, &c) in chars.iter().enumerate() {
match c {
'/' | '\\' => {
if !result.ends_with(' ') && !result.is_empty() {
result.push(' ');
}
}
'_' | '-' | '.' => {
if !result.ends_with(' ') {
result.push(' ');
}
}
c if c.is_uppercase() => {
if i > 0 {
let prev = chars[i - 1];
if prev.is_lowercase() {
result.push(' ');
}
}
result.push(c);
}
_ => {
result.push(c);
}
}
}
let normalized = result
.split_whitespace()
.collect::<Vec<_>>()
.join(" ")
.to_lowercase();
format!("{} {}", normalized, original_filename)
}
fn count_chars(s: &str) -> usize {
s.chars().count()
}
fn prefix_by_chars(s: &str, max_chars: usize) -> &str {
if max_chars == 0 {
return "";
}
match s.char_indices().nth(max_chars) {
Some((idx, _)) => &s[..idx],
None => s,
}
}
fn truncate_text(text: &str, max_chars: usize) -> String {
if count_chars(text) <= max_chars {
return text.to_string();
}
let marker_chars = count_chars(TRUNCATION_MARKER);
if max_chars <= marker_chars {
return prefix_by_chars(TRUNCATION_MARKER, max_chars).to_string();
}
let kept = prefix_by_chars(text, max_chars - marker_chars).trim_end();
format!("{kept}{TRUNCATION_MARKER}")
}
pub fn build_embedding_text(unit: &CodeUnit) -> String {
if unit.unit_type == UnitType::RawCode || unit.unit_type == UnitType::Constant {
return truncate_text(&unit.code, MAX_EMBEDDING_TEXT_CHARS);
}
let mut parts = Vec::new();
let type_str = match unit.unit_type {
UnitType::Function => "Function",
UnitType::Method => "Method",
UnitType::Class => "Class",
UnitType::Constant => "Constant",
UnitType::Document => "Document",
UnitType::Section => "Section",
UnitType::RawCode => "Code block",
};
parts.push(format!("{}: {}", type_str, unit.name));
if !unit.signature.is_empty() {
parts.push(format!("Signature: {}", unit.signature));
}
if let Some(parent) = &unit.extends {
if !parent.is_empty() {
parts.push(format!("Extends: {}", parent));
}
}
if let Some(class_name) = &unit.parent_class {
if !class_name.is_empty() {
parts.push(format!("Class: {}", class_name));
}
}
if let Some(doc) = &unit.docstring {
if !doc.is_empty() {
parts.push(format!("Description: {}", doc));
}
}
if !unit.parameters.is_empty() {
parts.push(format!("Parameters: {}", unit.parameters.join(", ")));
}
if let Some(ret) = &unit.return_type {
if !ret.is_empty() {
let label = if unit.unit_type == UnitType::Constant {
"Type"
} else {
"Returns"
};
parts.push(format!("{}: {}", label, ret));
}
}
if !unit.calls.is_empty() {
parts.push(format!("Calls: {}", unit.calls.join(", ")));
}
if !unit.called_by.is_empty() {
parts.push(format!("Called by: {}", unit.called_by.join(", ")));
}
if !unit.variables.is_empty() {
parts.push(format!("Variables: {}", unit.variables.join(", ")));
}
if !unit.imports.is_empty() {
parts.push(format!("Uses: {}", unit.imports.join(", ")));
}
let file_part = format!(
"File: {}",
normalize_path_for_embedding(&shorten_path(&unit.file))
);
parts.push(file_part);
if !unit.code.is_empty() {
parts.push(format!("Code:\n{}", unit.code));
}
truncate_text(&parts.join("\n"), MAX_EMBEDDING_TEXT_CHARS)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_normalize_path_separators() {
assert_eq!(
normalize_path_for_embedding("src/parser/mod.rs"),
"src parser mod mod.rs"
);
}
#[test]
fn test_normalize_backslash_separators() {
assert_eq!(
normalize_path_for_embedding("src\\parser\\mod.rs"),
"src parser mod mod.rs"
);
}
#[test]
fn test_normalize_underscores() {
assert_eq!(
normalize_path_for_embedding("my_file_name.py"),
"my file name my_file_name.py"
);
}
#[test]
fn test_normalize_hyphens() {
assert_eq!(
normalize_path_for_embedding("my-file-name.py"),
"my file name my-file-name.py"
);
}
#[test]
fn test_normalize_camel_case() {
assert_eq!(
normalize_path_for_embedding("MyClassName.ts"),
"my class name MyClassName.ts"
);
}
#[test]
fn test_normalize_camel_case_lowercase_start() {
assert_eq!(
normalize_path_for_embedding("myClassName.ts"),
"my class name myClassName.ts"
);
}
#[test]
fn test_normalize_combined() {
assert_eq!(
normalize_path_for_embedding("src/utils/HttpClientHelper.rs"),
"src utils http client helper HttpClientHelper.rs"
);
}
#[test]
fn test_normalize_snake_case_path() {
assert_eq!(
normalize_path_for_embedding("src/my_module/file_utils.py"),
"src my module file utils file_utils.py"
);
}
#[test]
fn test_normalize_mixed_separators() {
assert_eq!(
normalize_path_for_embedding("my_great-file.rs"),
"my great file my_great-file.rs"
);
}
#[test]
fn test_normalize_empty_string() {
assert_eq!(normalize_path_for_embedding(""), " ");
}
#[test]
fn test_normalize_simple_filename() {
assert_eq!(normalize_path_for_embedding("main.rs"), "main main.rs");
}
#[test]
fn test_build_embedding_text_truncates_raw_code_early() {
let mut unit = CodeUnit::new(
"raw".to_string(),
"src/huge.rs".into(),
1,
1,
crate::parser::Language::Rust,
UnitType::RawCode,
None,
);
unit.code = "x".repeat(MAX_EMBEDDING_TEXT_CHARS + 100);
let text = build_embedding_text(&unit);
assert_eq!(count_chars(&text), MAX_EMBEDDING_TEXT_CHARS);
assert!(text.contains("[...truncated...]"));
}
#[test]
fn test_build_embedding_text_puts_file_before_code_and_truncates_tail() {
let mut unit = CodeUnit::new(
"huge_fn".to_string(),
"src/some/deep/module/very_long_name.rs".into(),
1,
10,
crate::parser::Language::Rust,
UnitType::Function,
None,
);
unit.signature = "fn huge_fn()".to_string();
unit.code = "a".repeat(MAX_EMBEDDING_TEXT_CHARS + 500);
let text = build_embedding_text(&unit);
assert!(count_chars(&text) <= MAX_EMBEDDING_TEXT_CHARS);
assert!(text.contains("[...truncated...]"));
assert!(text.contains("File: "));
assert!(text.contains("File: some deep module very long name very_long_name.rs"));
let file_idx = text.find("File: ").unwrap();
let code_idx = text.find("Code:\n").unwrap();
assert!(file_idx < code_idx);
}
#[test]
fn test_normalize_consecutive_separators() {
assert_eq!(
normalize_path_for_embedding("my__file--name.rs"),
"my file name my__file--name.rs"
);
}
#[test]
fn test_normalize_no_extension() {
assert_eq!(
normalize_path_for_embedding("src/Makefile"),
"src makefile Makefile"
);
}
}