use crate::local_config::LocalConfig;
use serde::{Deserialize, Serialize};
use std::io::{BufRead, BufReader, Write};
use std::net::TcpStream;
use std::sync::atomic::{AtomicBool, Ordering};
/// Host used when the configuration provides no embeddings base URL.
const DEFAULT_HOST: &str = "localhost";
/// Port used when no base URL is configured, or when the configured URL
/// carries no parseable port.
const DEFAULT_PORT: u16 = 11434;
/// Request path used when no base URL is configured, or when the configured
/// URL has no path component.
const DEFAULT_PATH: &str = "/v1/embeddings";
/// Model name used when the configuration does not specify one.
const DEFAULT_MODEL: &str = "qwen3-embedding:8b";
/// Expected embedding length used when the configuration does not specify one.
const DEFAULT_DIMENSIONS: usize = 4096;
/// Set once the "embedding service unavailable" warning has been printed, so
/// that the warning is emitted at most once per process.
static WARNED_THIS_SESSION: AtomicBool = AtomicBool::new(false);
/// JSON request body that `EmbeddingClient::embed_batch` serializes and POSTs
/// to the embeddings endpoint.
#[derive(Debug, Serialize)]
struct EmbeddingRequest<'a> {
/// Name of the embedding model to use.
model: &'a str,
/// Texts to embed.
input: Vec<&'a str>,
}
/// Subset of the embeddings endpoint's JSON response that this module reads.
#[derive(Debug, Deserialize)]
struct EmbeddingResponse {
/// Embedding entries returned by the service.
data: Vec<EmbeddingData>,
}
/// A single entry of `EmbeddingResponse::data`.
#[derive(Debug, Deserialize)]
struct EmbeddingData {
/// The embedding vector itself.
embedding: Vec<f32>,
}
/// Errors produced while requesting embeddings from the service.
#[derive(Debug, thiserror::Error)]
pub enum EmbeddingError {
/// An I/O error occurred while connecting to, writing to, or reading from
/// the service.
#[error("HTTP request failed: {0}")]
Http(#[from] std::io::Error),
/// The service answered with a status code other than 200; `body` holds the
/// response body that accompanied it.
#[error("API returned error status {status}: {body}")]
Api { status: u16, body: String },
/// The request could not be serialized, or the response could not be decoded
/// (invalid UTF-8 or unexpected JSON).
#[error("Failed to parse embedding response: {0}")]
Parse(String),
/// A returned embedding's length differs from the client's configured
/// dimensions.
#[error("Dimension mismatch: expected {expected}, got {actual}")]
DimensionMismatch { expected: usize, actual: usize },
/// The response contained no embedding.
#[error("Empty response from API")]
EmptyResponse,
}
/// Blocking client for an HTTP embeddings endpoint.
///
/// Instances are created with `EmbeddingClient::from_config`.
pub struct EmbeddingClient {
/// Host name (or address) the client connects to.
host: String,
/// TCP port the client connects to.
port: u16,
/// Request path of the embeddings endpoint, e.g. `/v1/embeddings`.
path: String,
/// Model name sent with every request.
model: String,
/// Length every returned embedding is expected to have.
dimensions: usize,
}
impl EmbeddingClient {
    /// Builds a client from `config` and verifies that the service is reachable.
    ///
    /// Missing configuration values fall back to the module defaults. After
    /// construction a single probe request (embedding the text `"test"`) is
    /// sent; if it fails for any reason, `None` is returned.
    ///
    /// When `warn_on_error` is `true`, a failed probe prints a warning to
    /// stderr. The warning is printed at most once per process.
    pub fn from_config(config: &LocalConfig, warn_on_error: bool) -> Option<Self> {
        let (host, port, path) = if let Some(ref base_url) = config.embeddings.base_url {
            parse_base_url(base_url)
        } else {
            (
                DEFAULT_HOST.to_string(),
                DEFAULT_PORT,
                DEFAULT_PATH.to_string(),
            )
        };
        let model = config
            .embeddings
            .model
            .clone()
            .unwrap_or_else(|| DEFAULT_MODEL.to_string());
        let dimensions = config.embeddings.dimensions.unwrap_or(DEFAULT_DIMENSIONS);
        let client = Self {
            host,
            port,
            path,
            model,
            dimensions,
        };

        match client.embed("test") {
            Ok(_) => Some(client),
            Err(e) => {
                // `swap` is only evaluated when warnings are requested, so a
                // silent caller never consumes the one-time warning.
                if warn_on_error && !WARNED_THIS_SESSION.swap(true, Ordering::SeqCst) {
                    eprintln!(
                        "Warning: Embedding service unavailable at {}:{}: {}",
                        client.host, client.port, e
                    );
                    eprintln!("Semantic search features will be disabled.");
                }
                None
            }
        }
    }

    /// Returns the model name sent with every request.
    pub fn model(&self) -> &str {
        &self.model
    }

    /// Returns the embedding length this client expects from the service.
    pub fn dimensions(&self) -> usize {
        self.dimensions
    }

    /// Embeds a single text.
    ///
    /// # Errors
    ///
    /// Returns any error produced by [`Self::embed_batch`], or
    /// [`EmbeddingError::EmptyResponse`] if no embedding was returned.
    pub fn embed(&self, text: &str) -> Result<Vec<f32>, EmbeddingError> {
        let embeddings = self.embed_batch(&[text])?;
        embeddings
            .into_iter()
            .next()
            .ok_or(EmbeddingError::EmptyResponse)
    }

    /// Embeds several texts with one request, returning one vector per input.
    ///
    /// An empty `texts` slice returns an empty vector without contacting the
    /// service.
    ///
    /// # Errors
    ///
    /// * [`EmbeddingError::Http`] on connection or I/O failure.
    /// * [`EmbeddingError::Api`] if the service answers with a non-200 status.
    /// * [`EmbeddingError::Parse`] if the request cannot be serialized, the
    ///   response cannot be decoded, or the number of returned embeddings does
    ///   not match the number of inputs.
    /// * [`EmbeddingError::EmptyResponse`] if the response holds no embeddings.
    /// * [`EmbeddingError::DimensionMismatch`] if any returned embedding does
    ///   not have the configured length.
    pub fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, EmbeddingError> {
        if texts.is_empty() {
            return Ok(Vec::new());
        }

        let request = EmbeddingRequest {
            model: &self.model,
            input: texts.to_vec(),
        };
        let body =
            serde_json::to_string(&request).map_err(|e| EmbeddingError::Parse(e.to_string()))?;
        let response_body = self.http_post(&body)?;
        let response: EmbeddingResponse = serde_json::from_str(&response_body)
            .map_err(|e| EmbeddingError::Parse(format!("{}: {}", e, response_body)))?;

        if response.data.is_empty() {
            return Err(EmbeddingError::EmptyResponse);
        }
        // Callers pair results with inputs by position, so a short or long
        // response would silently mis-associate vectors.
        if response.data.len() != texts.len() {
            return Err(EmbeddingError::Parse(format!(
                "expected {} embeddings, got {}",
                texts.len(),
                response.data.len()
            )));
        }
        // Validate every vector, not just the first one.
        if let Some(bad) = response
            .data
            .iter()
            .find(|d| d.embedding.len() != self.dimensions)
        {
            return Err(EmbeddingError::DimensionMismatch {
                expected: self.dimensions,
                actual: bad.embedding.len(),
            });
        }

        Ok(response.data.into_iter().map(|d| d.embedding).collect())
    }

    /// Sends `body` as a JSON `POST` to the configured endpoint over a fresh
    /// TCP connection and returns the response body.
    ///
    /// The response body is read according to, in order of precedence,
    /// `Transfer-Encoding: chunked`, `Content-Length`, or end of stream
    /// (the request asks for `Connection: close`).
    ///
    /// # Errors
    ///
    /// * [`EmbeddingError::Http`] on any I/O failure (including timeouts).
    /// * [`EmbeddingError::Parse`] if the body is not valid UTF-8 or the
    ///   chunked framing is malformed.
    /// * [`EmbeddingError::Api`] if the status code is not 200.
    fn http_post(&self, body: &str) -> Result<String, EmbeddingError> {
        use std::io::Read;

        let addr = format!("{}:{}", self.host, self.port);
        let mut stream = TcpStream::connect(&addr)?;
        let io_timeout = Some(std::time::Duration::from_secs(30));
        stream.set_read_timeout(io_timeout)?;
        // Without a write timeout a stalled peer could block the request forever.
        stream.set_write_timeout(io_timeout)?;

        let request = format!(
            "POST {} HTTP/1.1\r\nHost: {}:{}\r\nContent-Type: application/json\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
            self.path, self.host, self.port, body.len(), body
        );
        stream.write_all(request.as_bytes())?;

        let mut reader = BufReader::new(stream);
        let mut status_line = String::new();
        reader.read_line(&mut status_line)?;
        let status_code = parse_status_code(&status_line);

        let mut content_length: Option<usize> = None;
        let mut chunked = false;
        let mut header = String::new();
        loop {
            header.clear();
            if reader.read_line(&mut header)? == 0 {
                break; // EOF before the blank line that ends the headers
            }
            let line = header.trim_end_matches(|c: char| c == '\r' || c == '\n');
            if line.is_empty() {
                break;
            }
            // Header field names are case-insensitive.
            if let Some(colon) = line.find(':') {
                let name = line[..colon].trim();
                let value = line[colon + 1..].trim();
                if name.eq_ignore_ascii_case("content-length") {
                    content_length = value.parse().ok();
                } else if name.eq_ignore_ascii_case("transfer-encoding") {
                    chunked = value
                        .split(',')
                        .any(|coding| coding.trim().eq_ignore_ascii_case("chunked"));
                }
            }
        }

        let response_body = if chunked {
            // An HTTP/1.1 server may answer with chunked framing; it takes
            // precedence over any Content-Length header.
            let raw = Self::read_chunked_body(&mut reader)?;
            String::from_utf8(raw)
                .map_err(|e| EmbeddingError::Parse(format!("invalid UTF-8 in response: {}", e)))?
        } else if let Some(len) = content_length {
            let mut buf = vec![0u8; len];
            reader.read_exact(&mut buf)?;
            String::from_utf8(buf)
                .map_err(|e| EmbeddingError::Parse(format!("invalid UTF-8 in response: {}", e)))?
        } else {
            let mut text = String::new();
            reader.read_to_string(&mut text)?;
            text
        };

        if status_code != 200 {
            return Err(EmbeddingError::Api {
                status: status_code,
                body: response_body,
            });
        }
        Ok(response_body)
    }

    /// Decodes a `Transfer-Encoding: chunked` body from `reader`, which must be
    /// positioned just after the blank line that ends the response headers.
    fn read_chunked_body<R: BufRead>(reader: &mut R) -> Result<Vec<u8>, EmbeddingError> {
        use std::io::Read;

        let mut body = Vec::new();
        let mut line = String::new();
        loop {
            line.clear();
            if reader.read_line(&mut line)? == 0 {
                return Err(EmbeddingError::Parse(
                    "unexpected end of chunked response".to_string(),
                ));
            }
            // The size line is a hex length, optionally followed by ";extensions".
            let size_field = line.split(';').next().unwrap_or("").trim();
            let size = u64::from_str_radix(size_field, 16).map_err(|e| {
                EmbeddingError::Parse(format!("invalid chunk size {:?}: {}", size_field, e))
            })?;
            if size == 0 {
                // Last chunk. Any trailer fields are ignored; the connection is
                // closed after this response anyway.
                break;
            }
            // `take` + `read_to_end` grows the buffer as data arrives rather
            // than pre-allocating whatever size the peer announces.
            let copied = reader.by_ref().take(size).read_to_end(&mut body)?;
            if copied as u64 != size {
                return Err(std::io::Error::new(
                    std::io::ErrorKind::UnexpectedEof,
                    "connection closed in the middle of a chunk",
                )
                .into());
            }
            // Consume the CRLF that terminates the chunk data.
            line.clear();
            reader.read_line(&mut line)?;
        }
        Ok(body)
    }
}
/// Splits an embeddings base URL into `(host, port, request_path)`.
///
/// * An `http://` or `https://` scheme prefix is stripped. TLS is not
///   supported, so `https://` URLs trigger a warning on stderr and are then
///   treated as plain HTTP.
/// * If the URL has a path, trailing slashes are removed and `/embeddings` is
///   appended (`http://host/v1/` becomes `/v1/embeddings`). If it has no path
///   at all, `DEFAULT_PATH` is used.
/// * The port is taken from the text after the last `:` of the authority.
///   A missing port yields `DEFAULT_PORT`; an unparseable one does too, after
///   a warning on stderr.
/// * Bracketed IPv6 literals are kept intact, brackets included: colons
///   inside `[...]` are never treated as the port separator.
fn parse_base_url(base_url: &str) -> (String, u16, String) {
    if base_url.starts_with("https://") {
        eprintln!(
            "Warning: HTTPS is not supported for embeddings (no TLS). Treating as HTTP: {}",
            base_url
        );
    }
    let without_scheme = base_url
        .strip_prefix("http://")
        .or_else(|| base_url.strip_prefix("https://"))
        .unwrap_or(base_url);

    let (authority, path) = match without_scheme.find('/') {
        Some(idx) => {
            let (authority, base_path) = without_scheme.split_at(idx);
            (
                authority,
                format!("{}/embeddings", base_path.trim_end_matches('/')),
            )
        }
        None => (without_scheme, DEFAULT_PATH.to_string()),
    };

    // The last colon separates host from port only if no `]` follows it;
    // otherwise it belongs to an IPv6 literal such as `[::1]`.
    let port_separator = authority
        .rfind(':')
        .filter(|&colon| !authority[colon..].contains(']'));

    match port_separator {
        Some(colon) => {
            let port_text = &authority[colon + 1..];
            let port = port_text.parse::<u16>().unwrap_or_else(|_| {
                eprintln!(
                    "Warning: invalid port '{}' in embeddings base URL {}; using default port {}",
                    port_text, base_url, DEFAULT_PORT
                );
                DEFAULT_PORT
            });
            (authority[..colon].to_string(), port, path)
        }
        None => (authority.to_string(), DEFAULT_PORT, path),
    }
}
/// Extracts the numeric status code from an HTTP status line such as
/// `HTTP/1.1 200 OK`.
///
/// The code is the second whitespace-separated token. Returns `0` when that
/// token is missing or is not a valid `u16`.
fn parse_status_code(status_line: &str) -> u16 {
    let mut tokens = status_line.split_whitespace();
    let _protocol = tokens.next();
    match tokens.next() {
        Some(code) => code.parse::<u16>().unwrap_or(0),
        None => 0,
    }
}
/// Computes the cosine similarity of two vectors.
///
/// Returns `0.0` when the slices are empty, differ in length, or when the
/// product of their magnitudes is zero.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }

    // One pass accumulating the dot product and both squared magnitudes.
    let (dot, sq_a, sq_b) = a.iter().zip(b.iter()).fold(
        (0.0f32, 0.0f32, 0.0f32),
        |(dot, sq_a, sq_b), (x, y)| (dot + x * y, sq_a + x * x, sq_b + y * y),
    );

    let magnitude = sq_a.sqrt() * sq_b.sqrt();
    if magnitude == 0.0 {
        return 0.0;
    }
    dot / magnitude
}
/// Builds the text for a problem: `title` and `description` separated by a
/// blank line, with leading and trailing whitespace removed from the result.
pub fn prepare_problem_text(title: &str, description: &str) -> String {
    let combined = [title, description].join("\n\n");
    combined.trim().to_owned()
}
/// Builds the text for a solution: `title` and `approach` separated by a
/// blank line, with leading and trailing whitespace removed from the result.
pub fn prepare_solution_text(title: &str, approach: &str) -> String {
    let combined = [title, approach].join("\n\n");
    combined.trim().to_owned()
}
/// Builds the text for a critique: `title` and `argument` separated by a
/// blank line, with leading and trailing whitespace removed from the result.
pub fn prepare_critique_text(title: &str, argument: &str) -> String {
    let combined = [title, argument].join("\n\n");
    combined.trim().to_owned()
}
/// Builds the text for a milestone: `title` and `description` separated by a
/// blank line, with leading and trailing whitespace removed from the result.
pub fn prepare_milestone_text(title: &str, description: &str) -> String {
    let combined = [title, description].join("\n\n");
    combined.trim().to_owned()
}
/// Unit tests for the pure helpers in this module. Nothing here opens a
/// network connection.
#[cfg(test)]
mod tests {
    use super::*;

    // --- cosine_similarity -------------------------------------------------

    #[test]
    fn test_cosine_similarity_identical() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![1.0, 2.0, 3.0];
        let sim = cosine_similarity(&a, &b);
        assert!((sim - 1.0).abs() < 0.0001);
    }

    #[test]
    fn test_cosine_similarity_orthogonal() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![0.0, 1.0, 0.0];
        let sim = cosine_similarity(&a, &b);
        assert!(sim.abs() < 0.0001);
    }

    #[test]
    fn test_cosine_similarity_opposite() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![-1.0, -2.0, -3.0];
        let sim = cosine_similarity(&a, &b);
        assert!((sim + 1.0).abs() < 0.0001);
    }

    #[test]
    fn test_cosine_similarity_empty() {
        let a: Vec<f32> = vec![];
        let b: Vec<f32> = vec![];
        let sim = cosine_similarity(&a, &b);
        assert_eq!(sim, 0.0);
    }

    #[test]
    fn test_cosine_similarity_different_lengths() {
        let a = vec![1.0, 2.0];
        let b = vec![1.0, 2.0, 3.0];
        let sim = cosine_similarity(&a, &b);
        assert_eq!(sim, 0.0);
    }

    /// A zero-magnitude vector must yield 0.0 rather than NaN.
    #[test]
    fn test_cosine_similarity_zero_vector() {
        let a = vec![0.0, 0.0, 0.0];
        let b = vec![1.0, 2.0, 3.0];
        assert_eq!(cosine_similarity(&a, &b), 0.0);
        assert_eq!(cosine_similarity(&b, &a), 0.0);
    }

    // --- prepare_*_text ----------------------------------------------------

    #[test]
    fn test_prepare_problem_text() {
        let text = prepare_problem_text("Title", "Description");
        assert_eq!(text, "Title\n\nDescription");
    }

    #[test]
    fn test_prepare_solution_text() {
        let text = prepare_solution_text("Title", "Approach");
        assert_eq!(text, "Title\n\nApproach");
    }

    #[test]
    fn test_prepare_critique_text() {
        let text = prepare_critique_text("Title", "Argument");
        assert_eq!(text, "Title\n\nArgument");
    }

    #[test]
    fn test_prepare_milestone_text() {
        let text = prepare_milestone_text("Title", "Description");
        assert_eq!(text, "Title\n\nDescription");
    }

    /// With an empty body the separator is trimmed away, leaving the title.
    #[test]
    fn test_prepare_text_trims_empty_body() {
        assert_eq!(prepare_problem_text("Title", ""), "Title");
        assert_eq!(prepare_solution_text("Title", ""), "Title");
        assert_eq!(prepare_critique_text("Title", ""), "Title");
        assert_eq!(prepare_milestone_text("Title", ""), "Title");
    }

    // --- parse_base_url ----------------------------------------------------

    #[test]
    fn test_parse_base_url_default() {
        let (host, port, path) = parse_base_url("http://localhost:11434/v1");
        assert_eq!(host, "localhost");
        assert_eq!(port, 11434);
        assert_eq!(path, "/v1/embeddings");
    }

    #[test]
    fn test_parse_base_url_custom_port() {
        let (host, port, path) = parse_base_url("http://localhost:9999/v1");
        assert_eq!(host, "localhost");
        assert_eq!(port, 9999);
        assert_eq!(path, "/v1/embeddings");
    }

    /// Without a path or port the module defaults are used.
    #[test]
    fn test_parse_base_url_without_path() {
        let (host, port, path) = parse_base_url("http://example.com");
        assert_eq!(host, "example.com");
        assert_eq!(port, DEFAULT_PORT);
        assert_eq!(path, DEFAULT_PATH);
    }

    #[test]
    fn test_parse_base_url_without_scheme() {
        let (host, port, path) = parse_base_url("localhost:8080");
        assert_eq!(host, "localhost");
        assert_eq!(port, 8080);
        assert_eq!(path, DEFAULT_PATH);
    }

    /// Trailing slashes must not produce a double slash in the request path.
    #[test]
    fn test_parse_base_url_trailing_slash() {
        let (_, _, path) = parse_base_url("http://localhost:11434/v1/");
        assert_eq!(path, "/v1/embeddings");
    }

    // --- parse_status_code -------------------------------------------------

    #[test]
    fn test_parse_status_code() {
        assert_eq!(parse_status_code("HTTP/1.1 200 OK\r\n"), 200);
        assert_eq!(parse_status_code("HTTP/1.1 404 Not Found\r\n"), 404);
        assert_eq!(
            parse_status_code("HTTP/1.1 500 Internal Server Error\r\n"),
            500
        );
    }

    /// Missing or non-numeric status codes are reported as 0.
    #[test]
    fn test_parse_status_code_malformed() {
        assert_eq!(parse_status_code(""), 0);
        assert_eq!(parse_status_code("HTTP/1.1"), 0);
        assert_eq!(parse_status_code("HTTP/1.1 abc OK\r\n"), 0);
    }
}