use std::time::Duration;
use serde::{Deserialize, Serialize};
use crate::error::TldrError;
use crate::TldrResult;
/// Base URL of the embedding service used when no URL is supplied.
pub const DEFAULT_EMBEDDING_URL: &str = "http://localhost:8765";
/// Timeout stored by clients created without an explicit timeout (ten seconds).
pub const DEFAULT_TIMEOUT: Duration = Duration::from_millis(10_000);
/// Number of components in each embedding vector returned by the client.
pub const EMBEDDING_DIM: usize = 1024;
/// Request body for the embedding service.
///
/// Field order is significant: it fixes the key order of the serialized payload.
#[derive(Clone, Debug, Serialize)]
struct EmbeddingRequest {
    /// Single text to embed; batch requests leave this as an empty string.
    text: String,
    /// Texts to embed in one batch; left out of the serialized payload when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    texts: Option<Vec<String>>,
}
/// Response body expected from the embedding service.
///
/// The fields carry a leading underscore so the compiler does not flag them as unused while
/// no code reads them. Serde would otherwise derive the JSON key from the field name,
/// underscore included, and look for `_embedding` / `_embeddings`. Each field is therefore
/// renamed to its un-prefixed key.
/// NOTE(review): assumes the service emits `embedding` / `embeddings` keys — confirm against
/// the service's API.
#[derive(Debug, Clone, Deserialize)]
struct _EmbeddingResponse {
    /// Vector for a single-text request; empty when the key is absent.
    #[serde(default, rename = "embedding")]
    _embedding: Vec<f32>,
    /// Vectors for a batch request; empty when the key is absent.
    #[serde(default, rename = "embeddings")]
    _embeddings: Vec<Vec<f32>>,
}
/// A single hit returned by a semantic search.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct SemanticResult {
    /// Identifier of the matching document.
    pub doc_id: String,
    /// Relevance score. NOTE(review): whether higher means more relevant is not established
    /// here — confirm against the service.
    pub score: f64,
    /// First line of the match (presumably inclusive — confirm indexing base).
    pub line_start: u32,
    /// Last line of the match (presumably inclusive — confirm indexing base).
    pub line_end: u32,
    /// Text excerpt for the match.
    pub snippet: String,
}
/// Request body for a semantic search.
///
/// Field order is significant: it fixes the key order of the serialized payload.
#[derive(Clone, Debug, Serialize)]
struct SearchRequest {
    /// Free-text search query.
    query: String,
    /// Number of results requested (presumably an upper bound — confirm with the service).
    top_k: usize,
    /// Project the search is scoped to (assumed to be an identifier or name — confirm).
    project: String,
}
#[derive(Debug, Clone, Deserialize)]
struct _SearchResponse {
_results: Vec<SemanticResult>,
}
/// Client for the embedding / semantic-search service.
#[derive(Clone, Debug)]
pub struct EmbeddingClient {
    /// Service base URL, stored without trailing `/`.
    base_url: String,
    /// Configured timeout. NOTE(review): stored but not yet read by any request path.
    _timeout: Duration,
}
impl Default for EmbeddingClient {
fn default() -> Self {
Self::new(DEFAULT_EMBEDDING_URL)
}
}
impl EmbeddingClient {
    /// Port assumed when `base_url` names none, or names one that does not parse.
    const DEFAULT_PORT: u16 = 8765;

    /// Connect timeout applied to each address probed by `is_available`.
    const PROBE_TIMEOUT: Duration = Duration::from_secs(1);

    /// Creates a client for `base_url` using [`DEFAULT_TIMEOUT`].
    ///
    /// Trailing `/` characters are removed from `base_url`.
    pub fn new(base_url: &str) -> Self {
        Self::with_timeout(base_url, DEFAULT_TIMEOUT)
    }

    /// Creates a client for `base_url` with an explicit `timeout`.
    ///
    /// Trailing `/` characters are removed from `base_url`.
    /// NOTE(review): `timeout` is stored but no method here reads it yet.
    pub fn with_timeout(base_url: &str, timeout: Duration) -> Self {
        Self {
            base_url: base_url.trim_end_matches('/').to_string(),
            _timeout: timeout,
        }
    }

    /// Returns the base URL without trailing `/`.
    pub fn base_url(&self) -> &str {
        &self.base_url
    }

    /// Returns `true` if a TCP connection to the service's host and port can be opened.
    ///
    /// Every address the host resolves to is tried in turn, each with a one-second connect
    /// timeout. For example, `localhost` commonly resolves to both `::1` and `127.0.0.1`, and the
    /// service may listen on only one of them. An unresolvable host yields `false` instead of
    /// probing some other machine.
    pub fn is_available(&self) -> bool {
        self.resolve_addresses()
            .iter()
            .any(|addr| std::net::TcpStream::connect_timeout(addr, Self::PROBE_TIMEOUT).is_ok())
    }

    /// Extracts the host and port from `base_url`.
    ///
    /// Accepts an optional `http://` or `https://` scheme, optional `user@` credentials and a
    /// bracketed IPv6 literal (`[::1]:8765`). Any path, query or fragment is ignored. A missing
    /// or unparseable port falls back to `DEFAULT_PORT`, regardless of scheme.
    fn host_and_port(&self) -> (&str, u16) {
        let url = self.base_url.as_str();
        let rest = url
            .strip_prefix("http://")
            .or_else(|| url.strip_prefix("https://"))
            .unwrap_or(url);
        // The authority ends at the first path, query or fragment delimiter.
        let authority = rest
            .split(|c: char| matches!(c, '/' | '?' | '#'))
            .next()
            .unwrap_or(rest);
        // Drop `user[:password]@` credentials, if present.
        let authority = authority
            .rsplit_once('@')
            .map_or(authority, |(_, host_port)| host_port);
        let (host, port) = match authority.strip_prefix('[') {
            // Bracketed IPv6 literal: `[addr]`, optionally followed by `:port`.
            Some(bracketed) => match bracketed.split_once(']') {
                Some((addr, tail)) => (addr, tail.strip_prefix(':')),
                None => (bracketed, None),
            },
            None => match authority.split_once(':') {
                Some((h, p)) => (h, Some(p)),
                None => (authority, None),
            },
        };
        let port = port
            .and_then(|p| p.parse().ok())
            .unwrap_or(Self::DEFAULT_PORT);
        (host, port)
    }

    /// Resolves the configured host and port to socket addresses.
    ///
    /// Returns an empty list when resolution fails.
    /// NOTE(review): `to_socket_addrs` may perform a blocking DNS lookup with no timeout.
    fn resolve_addresses(&self) -> Vec<std::net::SocketAddr> {
        use std::net::ToSocketAddrs;
        let (host, port) = self.host_and_port();
        // The `(&str, u16)` form accepts a bare IPv6 literal without brackets.
        (host, port)
            .to_socket_addrs()
            .map(|addrs| addrs.collect::<Vec<_>>())
            .unwrap_or_default()
    }

    /// Fails with `TldrError::ConnectionFailed` unless `is_available` returns `true`.
    fn ensure_available(&self) -> TldrResult<()> {
        if self.is_available() {
            Ok(())
        } else {
            Err(TldrError::ConnectionFailed(format!(
                "Embedding service at {} is not available",
                self.base_url
            )))
        }
    }

    /// Runs a semantic search for `query` within `project`, asking for `top_k` results.
    ///
    /// NOTE(review): placeholder. The request is built but never sent, so a reachable service
    /// always yields an empty list.
    ///
    /// # Errors
    /// Returns `TldrError::ConnectionFailed` when the service is not reachable.
    pub fn search(
        &self,
        query: &str,
        project: &str,
        top_k: usize,
    ) -> TldrResult<Vec<SemanticResult>> {
        self.ensure_available()?;
        let _request = SearchRequest {
            query: query.to_string(),
            top_k,
            project: project.to_string(),
        };
        Ok(Vec::new())
    }

    /// Embeds a single `text`.
    ///
    /// NOTE(review): placeholder. The request is built but never sent, so a reachable service
    /// always yields a zero vector of length [`EMBEDDING_DIM`].
    ///
    /// # Errors
    /// Returns `TldrError::ConnectionFailed` when the service is not reachable.
    pub fn embed(&self, text: &str) -> TldrResult<Vec<f32>> {
        self.ensure_available()?;
        let _request = EmbeddingRequest {
            text: text.to_string(),
            texts: None,
        };
        Ok(vec![0.0; EMBEDDING_DIM])
    }

    /// Embeds each of `texts`, returning one vector per input in the same order.
    ///
    /// NOTE(review): placeholder. The request is built but never sent, so a reachable service
    /// always yields zero vectors of length [`EMBEDDING_DIM`].
    ///
    /// # Errors
    /// Returns `TldrError::ConnectionFailed` when the service is not reachable (this is checked
    /// even when `texts` is empty).
    pub fn embed_batch(&self, texts: &[String]) -> TldrResult<Vec<Vec<f32>>> {
        self.ensure_available()?;
        let _request = EmbeddingRequest {
            text: String::new(),
            texts: Some(texts.to_vec()),
        };
        Ok(texts.iter().map(|_| vec![0.0; EMBEDDING_DIM]).collect())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A localhost URL on a port where no service is expected to listen.
    const UNREACHABLE_URL: &str = "http://localhost:59999";

    /// Asserts that `result` is a `ConnectionFailed` error whose message reports unavailability.
    fn assert_not_available<T>(result: TldrResult<T>) {
        match result {
            Err(TldrError::ConnectionFailed(msg)) => assert!(msg.contains("not available")),
            _ => panic!("Expected ConnectionFailed error"),
        }
    }

    #[test]
    fn test_client_creation() {
        let client = EmbeddingClient::new("http://localhost:8765");
        assert_eq!(client.base_url(), "http://localhost:8765");
    }

    #[test]
    fn test_client_with_trailing_slash() {
        let client = EmbeddingClient::new("http://localhost:8765/");
        assert_eq!(client.base_url(), "http://localhost:8765");
    }

    #[test]
    fn test_client_with_multiple_trailing_slashes() {
        let client = EmbeddingClient::new("http://localhost:8765///");
        assert_eq!(client.base_url(), "http://localhost:8765");
    }

    #[test]
    fn test_with_timeout_trims_trailing_slash() {
        let client = EmbeddingClient::with_timeout(
            "http://localhost:8765/",
            std::time::Duration::from_secs(2),
        );
        assert_eq!(client.base_url(), "http://localhost:8765");
    }

    #[test]
    fn test_default_client_uses_default_url() {
        assert_eq!(EmbeddingClient::default().base_url(), DEFAULT_EMBEDDING_URL);
    }

    #[test]
    fn test_client_unavailable() {
        let client = EmbeddingClient::new(UNREACHABLE_URL);
        assert!(!client.is_available());
    }

    #[test]
    fn test_search_unavailable_service() {
        let client = EmbeddingClient::new(UNREACHABLE_URL);
        assert_not_available(client.search("query", "project", 10));
    }

    #[test]
    fn test_embed_unavailable_service() {
        let client = EmbeddingClient::new(UNREACHABLE_URL);
        assert_not_available(client.embed("text"));
    }

    #[test]
    fn test_embed_batch_unavailable_service() {
        let client = EmbeddingClient::new(UNREACHABLE_URL);
        assert_not_available(client.embed_batch(&["a".to_string(), "b".to_string()]));
    }
}