#![allow(dead_code)]
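//! Semantic code search: chunking, embedding generation (OpenAI or a
//! deterministic keyword-hash fallback), and cosine-similarity scoring.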
use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
pub mod index;
pub mod search;
pub mod storage;
pub use index::SemanticIndex;
pub use search::SemanticSearch;
pub use storage::EmbeddingStorage;
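/// Tunable parameters for chunking, embedding, and similarity search.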
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SemanticConfig {
    /// Length of each embedding vector.
    pub embedding_dimension: usize,
    /// Maximum chunk size in bytes before a chunk is force-split.
    pub chunk_size: usize,
    /// Desired overlap between adjacent chunks (not yet applied by `chunk_code`).
    pub chunk_overlap: usize,
    /// Minimum similarity score for a match to count as a hit.
    pub similarity_threshold: f32,
}
impl Default for SemanticConfig {
fn default() -> Self {
Self {
embedding_dimension: 384,
chunk_size: 512,
chunk_overlap: 50,
similarity_threshold: 0.7,
}
}
}
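/// A chunk of source code paired with its embedding vector and provenance.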
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CodeEmbedding {
pub id: String,
pub content: String,
pub embedding: Vec<f32>,
pub metadata: EmbeddingMetadata,
pub created_at: chrono::DateTime<chrono::Utc>,
}
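/// Provenance and classification for an embedded chunk.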
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingMetadata {
pub source_file: String,
pub repository: String,
pub language: String,
pub start_line: usize,
pub end_line: usize,
pub function_name: Option<String>,
pub tags: Vec<String>,
}
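/// A single search hit: the matching embedding plus its similarity score.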
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchResult {
pub embedding: CodeEmbedding,
pub score: f32,
pub highlights: Vec<String>,
}
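/// Generates embeddings for code chunks. Uses the OpenAI embeddings API when
/// `OPENAI_API_KEY` is set; otherwise falls back to deterministic keyword
/// hashing (see `hash_bucket_embed`).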
pub struct EmbeddingGenerator {
config: SemanticConfig,
backend: EmbedderBackend,
#[allow(dead_code)]
vocabulary: HashMap<String, usize>,
}
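/// The embedding implementation selected at construction time.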
enum EmbedderBackend {
OpenAI(OpenAIEmbedder),
HashFallback,
}
struct OpenAIEmbedder {
client: reqwest::Client,
api_key: String,
base_url: String,
model: String,
}
#[derive(serde::Deserialize)]
struct EmbeddingResponse {
data: Vec<EmbeddingItem>,
}
#[derive(serde::Deserialize)]
struct EmbeddingItem {
embedding: Vec<f32>,
}
#[derive(serde::Deserialize)]
struct EmbeddingApiError {
error: EmbeddingApiErrorInner,
}
#[derive(serde::Deserialize)]
struct EmbeddingApiErrorInner {
message: String,
}
impl EmbeddingGenerator {
pub fn new(config: SemanticConfig) -> Result<Self> {
let backend = match std::env::var("OPENAI_API_KEY") {
Ok(key) if !key.is_empty() => {
tracing::info!(
"semantic search using OpenAI text-embedding-3-small @ {} dims",
config.embedding_dimension
);
let client = reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(30))
.build()?;
let base_url = std::env::var("OPENAI_BASE_URL")
.ok()
.filter(|s| !s.is_empty())
.unwrap_or_else(|| "https://api.openai.com/v1".to_string());
EmbedderBackend::OpenAI(OpenAIEmbedder {
client,
api_key: key,
base_url,
model: "text-embedding-3-small".to_string(),
})
}
_ => {
tracing::warn!(
"OPENAI_API_KEY not set; semantic search is using a keyword \
hash-bucket fallback (NOT real semantic embeddings). Set \
OPENAI_API_KEY (or OPENAI_BASE_URL for an OpenAI-compatible \
proxy / local model) for real embeddings."
);
EmbedderBackend::HashFallback
}
};
Ok(Self {
config,
backend,
vocabulary: HashMap::new(),
})
}
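    /// Embed a batch of texts. The hash fallback never fails; the OpenAI
    /// path surfaces HTTP and API errors from the embeddings endpoint.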
pub async fn embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
match &self.backend {
EmbedderBackend::OpenAI(b) => {
b.embed_batch(texts, self.config.embedding_dimension).await
}
EmbedderBackend::HashFallback => Ok(texts
.iter()
.map(|t| hash_bucket_embed(t, self.config.embedding_dimension))
.collect()),
}
}
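    /// Embed a single text; a convenience wrapper over the batch path.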
pub async fn embed_single(&self, text: &str) -> Result<Vec<f32>> {
match &self.backend {
EmbedderBackend::OpenAI(b) => {
let mut v = b
.embed_batch(&[text.to_string()], self.config.embedding_dimension)
.await?;
v.pop()
.ok_or_else(|| anyhow::anyhow!("empty embedding response"))
}
EmbedderBackend::HashFallback => {
Ok(hash_bucket_embed(text, self.config.embedding_dimension))
}
}
}
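    /// True when real (OpenAI) embeddings are in use rather than the
    /// keyword-hash fallback.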
pub fn is_semantic(&self) -> bool {
matches!(self.backend, EmbedderBackend::OpenAI(_))
}
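    /// Stable identifier for the active backend and dimension (e.g.
    /// `openai:text-embedding-3-small@384`). Vectors produced under
    /// different identifiers are not comparable and should not be mixed.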
pub fn id(&self) -> String {
let dim = self.config.embedding_dimension;
match &self.backend {
EmbedderBackend::OpenAI(b) => format!("openai:{}@{}", b.model, dim),
EmbedderBackend::HashFallback => format!("hash-bucket@{}", dim),
}
}
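    /// Split source code into chunks, starting a new chunk at language-aware
    /// boundaries (function/type definitions) and force-splitting any chunk
    /// that grows past `chunk_size` bytes. `chunk_overlap` is not yet
    /// applied here.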
pub fn chunk_code(&self, code: &str, file_path: &str) -> Vec<CodeChunk> {
let language = detect_language(file_path);
let lines: Vec<&str> = code.lines().collect();
let mut chunks = Vec::new();
let mut current_chunk = String::new();
let mut start_line = 0;
for (i, line) in lines.iter().enumerate() {
let is_boundary = is_code_boundary(line, &language);
if is_boundary && !current_chunk.is_empty() {
chunks.push(CodeChunk {
content: current_chunk.trim().to_string(),
start_line,
end_line: i,
language: language.clone(),
                    function_name: extract_function_name(&current_chunk, &language),
});
current_chunk.clear();
start_line = i;
}
current_chunk.push_str(line);
current_chunk.push('\n');
if current_chunk.len() > self.config.chunk_size {
chunks.push(CodeChunk {
content: current_chunk.trim().to_string(),
start_line,
end_line: i + 1,
language: language.clone(),
                    function_name: extract_function_name(&current_chunk, &language),
});
current_chunk.clear();
start_line = i + 1;
}
}
        if !current_chunk.is_empty() {
            let function_name = extract_function_name(&current_chunk, &language);
            chunks.push(CodeChunk {
                content: current_chunk.trim().to_string(),
                start_line,
                end_line: lines.len(),
                language,
                function_name,
            });
        }
chunks
}
}
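/// A contiguous slice of a source file produced by `chunk_code`.
/// `start_line` is zero-based and `end_line` is exclusive.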
#[derive(Debug, Clone)]
pub struct CodeChunk {
pub content: String,
pub start_line: usize,
pub end_line: usize,
pub language: String,
pub function_name: Option<String>,
}
impl OpenAIEmbedder {
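    /// Send one batch to the `/embeddings` endpoint and return the vectors.
    /// Assumes the response preserves input order.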
async fn embed_batch(
&self,
texts: &[String],
dimensions: usize,
) -> Result<Vec<Vec<f32>>> {
if texts.is_empty() {
return Ok(Vec::new());
}
let body = serde_json::json!({
"model": self.model,
"input": texts,
"dimensions": dimensions,
});
let resp = self
.client
.post(format!("{}/embeddings", self.base_url.trim_end_matches('/')))
.bearer_auth(&self.api_key)
.json(&body)
.send()
.await?;
let status = resp.status();
if !status.is_success() {
let raw = resp.text().await.unwrap_or_default();
let msg = serde_json::from_str::<EmbeddingApiError>(&raw)
.map(|e| e.error.message)
.unwrap_or(raw);
anyhow::bail!("OpenAI embeddings {}: {}", status, msg);
}
let parsed: EmbeddingResponse = resp.json().await?;
if parsed.data.len() != texts.len() {
anyhow::bail!(
"OpenAI returned {} embeddings for {} inputs",
parsed.data.len(),
texts.len()
);
}
Ok(parsed.data.into_iter().map(|d| d.embedding).collect())
}
}
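/// Fallback "embedding" via the hashing trick: each token (lowercased,
/// longer than two characters) is hashed into one of `dim` buckets, earlier
/// tokens receive slightly higher weight (1 / (1 + 0.1 * position)), and the
/// result is L2-normalized. Deterministic keyword matching, not a semantic
/// embedding.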
fn hash_bucket_embed(text: &str, dim: usize) -> Vec<f32> {
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
let mut embedding = vec![0.0f32; dim];
    let tokens: Vec<&str> = text
        .split_whitespace()
        .map(|t| t.trim_matches(|c: char| !c.is_alphanumeric()))
        .filter(|t| t.len() > 2)
        .collect();
for (i, token) in tokens.iter().enumerate() {
let mut hasher = DefaultHasher::new();
token.to_lowercase().hash(&mut hasher);
let idx = (hasher.finish() as usize) % dim;
let weight = 1.0 / (1.0 + (i as f32 * 0.1));
embedding[idx] += weight;
}
let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
if norm > 0.0 {
for x in &mut embedding {
*x /= norm;
}
}
embedding
}
fn detect_language(file_path: &str) -> String {
let ext = std::path::Path::new(file_path)
.extension()
.and_then(|e| e.to_str())
.unwrap_or("");
match ext {
"rs" => "Rust",
"py" => "Python",
"js" => "JavaScript",
"ts" => "TypeScript",
"go" => "Go",
"java" => "Java",
"cpp" | "cc" | "cxx" => "C++",
"c" => "C",
"rb" => "Ruby",
"php" => "PHP",
_ => "Unknown",
}.to_string()
}
fn is_code_boundary(line: &str, language: &str) -> bool {
    // Strip visibility/async modifiers so `pub fn`, `pub async fn`, etc.
    // are recognized as boundaries too.
    let line = line
        .trim()
        .trim_start_matches("pub(crate) ")
        .trim_start_matches("pub ")
        .trim_start_matches("async ");
    match language {
        "Rust" => line.starts_with("fn ") || line.starts_with("impl ") || line.starts_with("struct "),
        "Python" => line.starts_with("def ") || line.starts_with("class "),
        "JavaScript" | "TypeScript" => {
            line.starts_with("function ") || line.starts_with("const ") || line.starts_with("class ")
        }
        "Go" => line.starts_with("func "),
        "Java" => line.contains(" class ") || line.contains(" interface "),
        _ => false,
    }
}
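/// Best-effort extraction of a function name from the first line of a chunk
/// (Rust and Python only; other languages yield `None`).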
fn extract_function_name(code: &str, language: &str) -> Option<String> {
let first_line = code.lines().next()?;
match language {
"Rust" => {
first_line.split("fn ").nth(1)?
.split(|c: char| c == '(' || c == '<')
.next()
.map(|s| s.trim().to_string())
}
"Python" => {
first_line.split("def ").nth(1)?
.split('(')
.next()
.map(|s| s.trim().to_string())
}
_ => None,
}
}
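/// Cosine similarity in [-1.0, 1.0]; returns 0.0 for mismatched lengths or
/// zero-norm inputs.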
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
if a.len() != b.len() {
return 0.0;
}
let dot_product: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
if norm_a == 0.0 || norm_b == 0.0 {
return 0.0;
}
dot_product / (norm_a * norm_b)
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn fallback_used_when_no_api_key() {
if std::env::var("OPENAI_API_KEY").is_ok() {
return;
}
let gen = EmbeddingGenerator::new(SemanticConfig::default()).unwrap();
assert!(!gen.is_semantic());
let v = gen.embed_single("hello world").await.unwrap();
assert_eq!(v.len(), 384);
let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
assert!((norm - 1.0).abs() < 1e-3 || norm == 0.0);
}
#[test]
fn hash_bucket_is_deterministic() {
let a = hash_bucket_embed("fn main() {}", 384);
let b = hash_bucket_embed("fn main() {}", 384);
assert_eq!(a, b);
}
#[test]
fn hash_bucket_dimension_is_respected() {
assert_eq!(hash_bucket_embed("test", 128).len(), 128);
assert_eq!(hash_bucket_embed("test", 384).len(), 384);
assert_eq!(hash_bucket_embed("test", 1536).len(), 1536);
}
#[test]
fn cosine_similarity_orthogonal_vectors() {
let a = vec![1.0, 0.0, 0.0];
let b = vec![0.0, 1.0, 0.0];
assert!((cosine_similarity(&a, &b)).abs() < 1e-6);
}
#[test]
fn cosine_similarity_identical_vectors() {
let a = vec![1.0, 2.0, 3.0];
assert!((cosine_similarity(&a, &a) - 1.0).abs() < 1e-6);
}
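    // A minimal sketch of the chunker's contract, assuming the default
    // config and the `fn ` boundary heuristic for Rust sources.
    #[test]
    fn chunk_code_splits_on_function_boundaries() {
        let generator = EmbeddingGenerator::new(SemanticConfig::default()).unwrap();
        let code = "fn first() {}\nfn second() {}\n";
        let chunks = generator.chunk_code(code, "lib.rs");
        assert_eq!(chunks.len(), 2);
        assert_eq!(chunks[0].function_name.as_deref(), Some("first"));
        assert_eq!(chunks[0].start_line, 0);
        assert_eq!(chunks[0].end_line, 1);
        assert_eq!(chunks[1].function_name.as_deref(), Some("second"));
    }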
}