use crate::tokenizer::get_tokenizer;
use bm25::Tokenizer;
use serde::{Deserialize, Serialize};
use std::fs;
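/// Serialized index produced by `scan_files`: dense embeddings for each file
/// chunk (`rag_vectors`), sparse BM25 vectors for each whole file
/// (`bm25_files`), plus the scan settings and BM25 corpus statistics.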
#[derive(Serialize, Deserialize)]
pub struct FileVectorStore {
    pub rag_vectors: Vec<RagEmbeddedFileChunk>,
    pub bm25_files: Vec<Bm25EmbeddedFile>,
    pub pattern: String,
    pub chunk_size: usize,
    pub overlap_size: usize,
    pub embed_metadata: bool,
    pub date: u64,
    pub bm25_avgdl: f32,
    pub doc_count: usize,
}
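/// Sparse BM25 representation of one whole file (files are not chunked for
/// BM25; only the dense embeddings are).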
#[derive(Serialize, Deserialize)]
pub struct Bm25EmbeddedFile {
    pub filename: String,
    /// Parallel arrays encoding the sparse BM25 vector: term indices and
    /// their corresponding weights.
    pub bm25_indices: Vec<u32>,
    pub bm25_values: Vec<f32>,
    pub tokens: Vec<String>,
    pub token_count: usize,
    pub last_modified: u64,
    pub has_metadata: bool,
}
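/// Dense embedding of a single chunk of a file (or of the whole file when
/// `chunk_size` is 0).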
#[derive(Serialize, Deserialize)]
pub struct RagEmbeddedFileChunk {
    pub filename: String,
    pub vector: Vec<f32>,
    pub last_modified: u64,
    /// Byte offset of this chunk within the original file.
    pub chunk_offset: usize,
    /// Length in bytes of the chunk that was embedded.
    pub chunk_size: usize,
    /// True when chunking was disabled and the whole file was embedded.
    pub is_full_file: bool,
    pub has_metadata: bool,
}
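/// Scans files matching `pattern`, builds dense chunk embeddings (via the
/// OpenAI embedding endpoint) and sparse BM25 vectors, and writes the result
/// to `.luckyshot.file.vectors.v1` in the current directory.
///
/// A `chunk_size` of 0 embeds each file whole; otherwise chunks of
/// `chunk_size` bytes overlap by `overlap_size` bytes. A sketch of a typical
/// call (the pattern and sizes are illustrative):
///
/// ```ignore
/// scan_files("src/**/*.rs", &api_key, 1000, 200, true).await?;
/// ```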
pub async fn scan_files(
    pattern: &str,
    api_key: &str,
    chunk_size: usize,
    overlap_size: usize,
    embed_metadata: bool,
) -> Result<(), Box<dyn std::error::Error>> {
    // The overlap must leave room for forward progress between chunk starts.
    if chunk_size > 0 && overlap_size >= chunk_size {
        return Err("overlap_size must be less than chunk_size".into());
    }
    println!("Scanning for files matching pattern: {}", pattern);
    let mut store = FileVectorStore {
        rag_vectors: Vec::new(),
        bm25_files: Vec::new(),
        pattern: pattern.to_string(),
        chunk_size,
        overlap_size,
        embed_metadata,
        date: std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)?
            .as_secs(),
        bm25_avgdl: 0.0,
        doc_count: 0,
    };
    // Corpus totals for BM25, filled in by the first pass below.
    let mut total_tokens = 0;
    let mut doc_count = 0;
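    /// Splits `content` into byte-offset/chunk pairs. Consecutive chunk
    /// starts are `chunk_size - overlap_size` bytes apart, so with
    /// `chunk_size = 1000` and `overlap_size = 200` chunks begin at offsets
    /// 0, 800, 1600, ... A `chunk_size` of 0 yields the whole file as one
    /// chunk.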
    fn create_chunks(
        content: &str,
        chunk_size: usize,
        overlap_size: usize,
    ) -> Vec<(usize, String)> {
        if chunk_size == 0 {
            return vec![(0, content.to_string())];
        }
        let mut chunks = Vec::new();
        let content_len = content.len();
        let mut offset = 0;
        while offset < content_len {
            let mut end = (offset + chunk_size).min(content_len);
            // Offsets are in bytes, so snap to a char boundary before
            // slicing; slicing mid-character would panic on UTF-8 input.
            while !content.is_char_boundary(end) {
                end -= 1;
            }
            chunks.push((offset, content[offset..end].to_string()));
            if end == content_len {
                break;
            }
            offset += chunk_size - overlap_size;
            while offset < content_len && !content.is_char_boundary(offset) {
                offset += 1;
            }
        }
        chunks
    }
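    /// Embeds a single file: one whole-file BM25 entry plus one dense
    /// embedding per chunk. Fails fast if any embedding request errors.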
    async fn process_file(
        path: &std::path::Path,
        path_str: &str,
        store: &mut FileVectorStore,
        api_key: &str,
        chunk_size: usize,
        overlap_size: usize,
        embed_metadata: bool,
    ) -> Result<(), Box<dyn std::error::Error>> {
println!("Processing: {}", path_str);
let contents = fs::read_to_string(path)?;
let metadata = fs::metadata(path)?;
let last_modified = metadata
.modified()?
.duration_since(std::time::UNIX_EPOCH)?
.as_secs();
let content_to_embed = if embed_metadata {
let size = metadata.len();
crate::metadata::prepend_metadata(path_str, last_modified, size, &contents)
} else {
contents.clone()
};
        // BM25 indexes the file as a whole; its weights depend on the
        // corpus-wide average document length computed in the first pass.
        let bm25_vec =
            crate::bm25_embedder::create_bm25_vector(&content_to_embed, store.bm25_avgdl);
        let tokens = get_tokenizer().tokenize(&content_to_embed);
        store.bm25_files.push(Bm25EmbeddedFile {
            filename: path_str.to_string(),
            bm25_indices: bm25_vec.indices,
            bm25_values: bm25_vec.values,
            tokens: tokens.clone(),
            token_count: tokens.len(),
            last_modified,
            has_metadata: embed_metadata,
        });
        // Dense embeddings are per chunk; chunking runs over the raw
        // contents, with metadata (when enabled) prepended to each chunk.
        let chunks = create_chunks(&contents, chunk_size, overlap_size);
        for (offset, chunk_content) in chunks {
            let chunk_to_embed = if embed_metadata {
                let size = metadata.len();
                crate::metadata::prepend_metadata(path_str, last_modified, size, &chunk_content)
            } else {
                chunk_content.clone()
            };
            match crate::openai::get_embedding(&chunk_to_embed, api_key).await {
                Ok(embedding) => {
                    store.rag_vectors.push(RagEmbeddedFileChunk {
                        filename: path_str.to_string(),
                        vector: embedding,
                        last_modified,
                        chunk_offset: offset,
                        chunk_size: chunk_content.len(),
                        is_full_file: chunk_size == 0,
                        has_metadata: embed_metadata,
                    });
                    if chunk_size > 0 {
                        println!("Got embedding for {} (chunk offset: {})", path_str, offset);
                    } else {
                        println!("Got embedding for {}", path_str);
                    }
                }
                Err(e) => {
                    eprintln!("Error getting embedding for {}: {}", path_str, e);
                    return Err(e);
                }
            }
        }
        Ok(())
    }
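    // BM25 scoring needs the corpus-wide average document length (avgdl)
    // before any per-file weights can be computed, so the corpus is walked
    // twice: once to count tokens, once to embed.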
    let matching_files = crate::files::find_matching_files(pattern);
    let cwd = std::env::current_dir()?;
    println!("Calculating average document length...");
    for path in matching_files.iter() {
        let relative_path = path.strip_prefix(&cwd).unwrap_or(path);
        let path_str = relative_path.to_string_lossy().to_string();
        // Never index the vector store file itself.
        if path_str.ends_with(".luckyshot.file.vectors.v1") {
            continue;
        }
        let contents = fs::read_to_string(path)?;
        let tokens = get_tokenizer().tokenize(&contents);
        total_tokens += tokens.len();
        doc_count += 1;
    }
    // Guard the division: no matches would otherwise yield a NaN avgdl.
    if doc_count == 0 {
        return Err(format!("no files matched pattern: {}", pattern).into());
    }
    store.bm25_avgdl = total_tokens as f32 / doc_count as f32;
    store.doc_count = doc_count;
    println!("Average document length: {:.2}", store.bm25_avgdl);
    for path in matching_files {
        let relative_path = path.strip_prefix(&cwd).unwrap_or(&path);
        let path_str = relative_path.to_string_lossy().to_string();
        if path_str.ends_with(".luckyshot.file.vectors.v1") {
            continue;
        }
        process_file(
            &path,
            &path_str,
            &mut store,
            api_key,
            chunk_size,
            overlap_size,
            embed_metadata,
        )
        .await?;
    }
    match serde_json::to_string_pretty(&store) {
        Ok(json) => {
            if let Err(e) = fs::write(".luckyshot.file.vectors.v1", json) {
                eprintln!("Error writing vectors file: {}", e);
            } else {
                println!(
                    "Successfully saved vectors for {} chunks",
                    store.rag_vectors.len()
                );
                use colored::*;
                use figlet_rs::FIGfont;
                // The banner is decorative, so a missing font shouldn't
                // abort an otherwise successful scan.
                if let Ok(font) = FIGfont::standard() {
                    if let Some(figure) = font.convert("Yee-haw!") {
                        println!("\n{}", figure.to_string().bright_yellow());
                    }
                }
            }
        }
        Err(e) => eprintln!("Error serializing vectors: {}", e),
    }
    Ok(())
}