use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::error::Error;
use std::fs;
#[derive(Debug, Serialize)]
struct OpenAIChatRequest {
    model: String,
    messages: Vec<ChatMessage>,
    temperature: f32,
}

#[derive(Debug, Deserialize)]
struct OpenAIChatResponse {
    choices: Vec<OpenAIChatChoice>,
}

#[derive(Debug, Deserialize)]
struct OpenAIChatChoice {
    message: ChatMessage,
}

#[derive(Debug, Serialize)]
struct ChatRequest {
    model: String,
    messages: Vec<ChatMessage>,
}

#[derive(Debug, Deserialize)]
struct ChatResponse {
    choices: Vec<ChatChoice>,
}

#[derive(Debug, Deserialize)]
struct ChatChoice {
    message: ChatMessage,
}

#[derive(Debug, Serialize, Deserialize)]
struct ChatMessage {
    role: String,
    content: String,
}

#[derive(Debug)]
pub struct FileMatch {
    pub filename: String,
    pub similarity: f32,
}

#[derive(Debug, Serialize)]
struct EmbeddingRequest {
    input: String,
    model: String,
}
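/// Loads cached embeddings from `.luckyshot.file.vectors.v1`, scores each one
/// against `query_embedding`, prints a CSV-style report to stdout, and returns
/// every filename ordered by descending (min-max normalized) similarity.
/// Returns an empty list if the cache file is missing or fails to parse.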
pub async fn find_related_files(query_embedding: Vec<f32>) -> Vec<String> {
    // Load the embedding cache written by the scan step; bail out with an
    // empty result if it is missing or corrupt.
    let vectors_content = match fs::read_to_string(".luckyshot.file.vectors.v1") {
        Ok(content) => content,
        Err(e) => {
            eprintln!("Error reading vectors file: {}", e);
            return Vec::new();
        }
    };
    let file_embeddings: Vec<crate::scan::FileEmbedding> =
        match serde_json::from_str(&vectors_content) {
            Ok(embeddings) => embeddings,
            Err(e) => {
                eprintln!("Error parsing vectors file: {}", e);
                return Vec::new();
            }
        };
    // Score every cached embedding against the query.
    let mut matches: Vec<FileMatch> = file_embeddings
        .iter()
        .map(|embedding| {
            let similarity = bm25_similarity(&query_embedding, &embedding.vector);
            FileMatch {
                filename: embedding.filename.clone(),
                similarity,
            }
        })
        .collect();

    // Min-max normalize scores into [0, 1]; skipped when all scores are
    // (near-)identical, which would otherwise divide by zero.
    let min_similarity = matches.iter().map(|m| m.similarity).fold(f32::INFINITY, f32::min);
    let max_similarity = matches.iter().map(|m| m.similarity).fold(f32::NEG_INFINITY, f32::max);
    if (max_similarity - min_similarity).abs() > f32::EPSILON {
        for m in &mut matches {
            m.similarity = (m.similarity - min_similarity) / (max_similarity - min_similarity);
        }
    }

    // total_cmp gives a total order over f32, avoiding the panic that
    // partial_cmp().unwrap() would hit if a score were NaN.
    matches.sort_by(|a, b| b.similarity.total_cmp(&a.similarity));
    // Print one CSV row per match: score, filename, full-file vs. chunk
    // marker, chunk offset, chunk size. The lookup cannot fail because every
    // match was built from file_embeddings above.
    for m in &matches {
        let embedding = file_embeddings
            .iter()
            .find(|e| e.filename == m.filename)
            .unwrap();
        println!(
            "{:.3},{},{},{},{}",
            m.similarity,
            m.filename,
            if embedding.is_full_file { "full" } else { "chunk" },
            embedding.chunk_offset,
            embedding.chunk_size
        );
    }
    matches.iter().map(|m| m.filename.clone()).collect()
}
fn bm25_similarity(query: &[f32], doc: &[f32]) -> f32 {
    // BM25-style weighting applied to dense embedding components: |d| stands
    // in for a term frequency and ln(|q| + EPSILON) for an IDF weight.
    const K1: f32 = 1.5;
    const B: f32 = 0.75;
    const EPSILON: f32 = 1e-10;

    // avg_dl equals doc_len here (there is no corpus-wide average), so the
    // B-dependent length normalization always reduces to 1.0.
    let avg_dl = doc.len() as f32;
    let doc_len = doc.len() as f32;

    let mut score = 0.0;
    for (q, d) in query.iter().zip(doc.iter()) {
        let tf = d.abs();
        let idf = (q.abs() + EPSILON).ln();
        let numerator = tf * (K1 + 1.0);
        let denominator = tf + K1 * (1.0 - B + B * doc_len / avg_dl);
        score += idf * numerator / denominator;
    }
    score
}
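// The BM25-style scoring above treats embedding components as term
// frequencies, which is unconventional for dense vectors; cosine similarity
// is the usual metric for them. A minimal sketch follows as an alternative.
// `cosine_similarity` is a hypothetical helper and is not wired into
// `find_related_files`.
#[allow(dead_code)]
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
    let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm_a == 0.0 || norm_b == 0.0 {
        // Degenerate (all-zero) vectors have no meaningful direction.
        0.0
    } else {
        dot / (norm_a * norm_b)
    }
}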
#[derive(Debug, Deserialize)]
struct EmbeddingResponse {
    data: Vec<EmbeddingData>,
    model: String,
    usage: Usage,
}

#[derive(Debug, Deserialize)]
struct EmbeddingData {
    embedding: Vec<f32>,
    index: i32,
    object: String,
}

#[derive(Debug, Deserialize)]
struct Usage {
    prompt_tokens: i32,
    total_tokens: i32,
}
pub async fn get_chat_completion(prompt: &str, api_key: &str) -> Result<String, Box<dyn Error>> {
    let client = Client::new();
    let request = ChatRequest {
        model: "gpt-4o-mini".to_string(),
        messages: vec![ChatMessage {
            role: "user".to_string(),
            content: prompt.to_string(),
        }],
    };
    // The request body and the `choices` response shape follow OpenAI's chat
    // completions API, so this must target the OpenAI endpoint with Bearer
    // auth; Anthropic's messages API uses a different payload and headers.
    let response = client
        .post("https://api.openai.com/v1/chat/completions")
        .header("Authorization", format!("Bearer {}", api_key))
        .header("Content-Type", "application/json")
        .json(&request)
        .send()
        .await?
        .error_for_status()?;
    let chat_response: ChatResponse = response.json().await?;
    // Take the first choice without indexing, which would panic on an empty list.
    let choice = chat_response.choices.into_iter().next().ok_or("response contained no choices")?;
    Ok(choice.message.content)
}
pub async fn get_openai_chat_completion(
    prompt: &str,
    system_prompt: &str,
    api_key: &str,
) -> Result<String, Box<dyn Error>> {
    let client = Client::new();
    let request = OpenAIChatRequest {
        model: "gpt-4o-mini".to_string(),
        messages: vec![
            ChatMessage {
                role: "system".to_string(),
                content: system_prompt.to_string(),
            },
            ChatMessage {
                role: "user".to_string(),
                content: prompt.to_string(),
            },
        ],
        temperature: 0.7,
    };
    let response = client
        .post("https://api.openai.com/v1/chat/completions")
        .header("Authorization", format!("Bearer {}", api_key))
        .header("Content-Type", "application/json")
        .json(&request)
        .send()
        .await?
        // Surface HTTP errors directly instead of failing later in JSON parsing.
        .error_for_status()?;
    let chat_response: OpenAIChatResponse = response.json().await?;
    let choice = chat_response.choices.into_iter().next().ok_or("response contained no choices")?;
    Ok(choice.message.content)
}
pub async fn get_embedding(text: &str, api_key: &str) -> Result<Vec<f32>, Box<dyn Error>> {
    let client = Client::new();
    let request = EmbeddingRequest {
        input: text.to_string(),
        model: "text-embedding-ada-002".to_string(),
    };
    let response = client
        .post("https://api.openai.com/v1/embeddings")
        .header("Authorization", format!("Bearer {}", api_key))
        .header("Content-Type", "application/json")
        .json(&request)
        .send()
        .await?
        .error_for_status()?;
    // Deserialize the JSON body straight into the typed response.
    let embedding_response: EmbeddingResponse = response.json().await?;
    let first = embedding_response.data.into_iter().next().ok_or("response contained no embedding data")?;
    Ok(first.embedding)
}
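
// A minimal end-to-end sketch, assuming the key lives in an OPENAI_API_KEY
// environment variable (an assumption of this example, not something this
// module requires): embed a query, then rank the cached files against it.
#[allow(dead_code)]
async fn demo_find_related_files() -> Result<(), Box<dyn Error>> {
    let api_key = std::env::var("OPENAI_API_KEY")?; // hypothetical env var
    let query_embedding = get_embedding("where is the vectors cache parsed?", &api_key).await?;
    let related = find_related_files(query_embedding).await;
    println!("top match: {:?}", related.first());
    Ok(())
}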