1use crate::error::{HeliosError, Result};
11use async_trait::async_trait;
12use reqwest::Client;
13use serde::{Deserialize, Serialize};
14use std::collections::HashMap;
15use uuid::Uuid;
16
17#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct Document {
24 pub id: String,
26 pub text: String,
28 pub metadata: HashMap<String, serde_json::Value>,
30 pub timestamp: String,
32}
33
34#[derive(Debug, Clone, Serialize, Deserialize)]
36pub struct SearchResult {
37 pub id: String,
39 pub score: f64,
41 pub text: String,
43 pub metadata: Option<HashMap<String, serde_json::Value>>,
45}
46
/// Source of dense embedding vectors for text.
///
/// The `Send + Sync` supertraits allow a provider to be stored as
/// `Box<dyn EmbeddingProvider>` and used from `async_trait` futures.
#[async_trait]
pub trait EmbeddingProvider: Send + Sync {
    /// Computes the embedding vector for `text`.
    ///
    /// # Errors
    /// Returns an error when the embedding cannot be produced (for the
    /// HTTP-backed implementation: transport, status or parse failures).
    async fn embed(&self, text: &str) -> Result<Vec<f32>>;

    /// Length of the vectors returned by [`embed`](Self::embed).
    ///
    /// `RAGSystem` passes this value to [`VectorStore::initialize`].
    fn dimension(&self) -> usize;
}
60
/// Storage backend that keeps embeddings together with their text and
/// metadata and answers similarity queries.
#[async_trait]
pub trait VectorStore: Send + Sync {
    /// Prepares the store for vectors of length `dimension`.
    ///
    /// `RAGSystem` calls this lazily before its first add/search.
    /// NOTE(review): both implementations in this module tolerate repeated
    /// calls; confirm any new implementation does too.
    async fn initialize(&self, dimension: usize) -> Result<()>;

    /// Stores `embedding`, `text` and `metadata` under `id`.
    async fn add(
        &self,
        id: &str,
        embedding: Vec<f32>,
        text: &str,
        metadata: HashMap<String, serde_json::Value>,
    ) -> Result<()>;

    /// Returns at most `limit` stored entries ranked by similarity to
    /// `query_embedding`, most similar first.
    async fn search(&self, query_embedding: Vec<f32>, limit: usize) -> Result<Vec<SearchResult>>;

    /// Removes the entry stored under `id`.
    async fn delete(&self, id: &str) -> Result<()>;

    /// Removes every entry.
    ///
    /// NOTE(review): `QdrantVectorStore` implements this by deleting the
    /// whole collection, so `initialize` must run again afterwards.
    async fn clear(&self) -> Result<()>;

    /// Returns the number of stored entries.
    async fn count(&self) -> Result<usize>;
}
92
93pub struct OpenAIEmbeddings {
99 api_url: String,
100 api_key: String,
101 model: String,
102 client: Client,
103}
104
105#[derive(Debug, Serialize)]
106struct OpenAIEmbeddingRequest {
107 input: String,
108 model: String,
109}
110
111#[derive(Debug, Deserialize)]
112struct OpenAIEmbeddingResponse {
113 data: Vec<OpenAIEmbeddingData>,
114}
115
116#[derive(Debug, Deserialize)]
117struct OpenAIEmbeddingData {
118 embedding: Vec<f32>,
119}
120
121impl OpenAIEmbeddings {
122 pub fn new(api_url: impl Into<String>, api_key: impl Into<String>) -> Self {
124 Self {
125 api_url: api_url.into(),
126 api_key: api_key.into(),
127 model: "text-embedding-ada-002".to_string(),
128 client: Client::new(),
129 }
130 }
131
132 pub fn with_model(
134 api_url: impl Into<String>,
135 api_key: impl Into<String>,
136 model: impl Into<String>,
137 ) -> Self {
138 Self {
139 api_url: api_url.into(),
140 api_key: api_key.into(),
141 model: model.into(),
142 client: Client::new(),
143 }
144 }
145}
146
147#[async_trait]
148impl EmbeddingProvider for OpenAIEmbeddings {
149 async fn embed(&self, text: &str) -> Result<Vec<f32>> {
150 let request = OpenAIEmbeddingRequest {
151 input: text.to_string(),
152 model: self.model.clone(),
153 };
154
155 let response = self
156 .client
157 .post(&self.api_url)
158 .header("Authorization", format!("Bearer {}", self.api_key))
159 .json(&request)
160 .send()
161 .await
162 .map_err(|e| HeliosError::ToolError(format!("Embedding API error: {}", e)))?;
163
164 if !response.status().is_success() {
165 let error_text = response
166 .text()
167 .await
168 .unwrap_or_else(|_| "Unknown error".to_string());
169 return Err(HeliosError::ToolError(format!(
170 "Embedding API failed: {}",
171 error_text
172 )));
173 }
174
175 let embedding_response: OpenAIEmbeddingResponse = response.json().await.map_err(|e| {
176 HeliosError::ToolError(format!("Failed to parse embedding response: {}", e))
177 })?;
178
179 embedding_response
180 .data
181 .into_iter()
182 .next()
183 .map(|d| d.embedding)
184 .ok_or_else(|| HeliosError::ToolError("No embedding returned".to_string()))
185 }
186
187 fn dimension(&self) -> usize {
188 if self.model == "text-embedding-3-large" {
192 3072
193 } else {
194 1536
195 }
196 }
197}
198
199pub struct InMemoryVectorStore {
205 documents:
206 std::sync::Arc<tokio::sync::RwLock<std::collections::HashMap<String, StoredDocument>>>,
207}
208
209#[derive(Debug, Clone)]
210struct StoredDocument {
211 id: String,
212 embedding: Vec<f32>,
213 text: String,
214 metadata: HashMap<String, serde_json::Value>,
215}
216
217impl InMemoryVectorStore {
218 pub fn new() -> Self {
220 Self {
221 documents: std::sync::Arc::new(tokio::sync::RwLock::new(HashMap::new())),
222 }
223 }
224}
225
226impl Default for InMemoryVectorStore {
227 fn default() -> Self {
228 Self::new()
229 }
230}
231
/// Cosine similarity of `a` and `b`, in `[-1.0, 1.0]` for finite inputs.
///
/// Returns `0.0` when the slices differ in length or either vector has zero
/// magnitude (including empty slices), so such pairs rank as unrelated
/// instead of erroring.
///
/// All sums are accumulated in `f64`. Summing in `f32` loses precision on
/// high-dimensional embeddings and can underflow tiny components to zero.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f64 {
    if a.len() != b.len() {
        return 0.0;
    }

    // A single pass computes the dot product and both squared norms.
    let (dot, sq_a, sq_b) = a
        .iter()
        .zip(b)
        .fold((0.0f64, 0.0f64, 0.0f64), |(dot, sq_a, sq_b), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, sq_a + x * x, sq_b + y * y)
        });

    if sq_a == 0.0 || sq_b == 0.0 {
        return 0.0;
    }

    // One sqrt of the product equals ‖a‖·‖b‖. f32-derived magnitudes cannot
    // overflow or underflow that product in f64.
    dot / (sq_a * sq_b).sqrt()
}
248
249#[async_trait]
250impl VectorStore for InMemoryVectorStore {
251 async fn initialize(&self, _dimension: usize) -> Result<()> {
252 Ok(())
254 }
255
256 async fn add(
257 &self,
258 id: &str,
259 embedding: Vec<f32>,
260 text: &str,
261 metadata: HashMap<String, serde_json::Value>,
262 ) -> Result<()> {
263 let mut docs = self.documents.write().await;
264
265 docs.insert(
267 id.to_string(),
268 StoredDocument {
269 id: id.to_string(),
270 embedding,
271 text: text.to_string(),
272 metadata,
273 },
274 );
275
276 Ok(())
277 }
278
279 async fn search(&self, query_embedding: Vec<f32>, limit: usize) -> Result<Vec<SearchResult>> {
280 let docs = self.documents.read().await;
281
282 if docs.is_empty() {
283 return Ok(Vec::new());
284 }
285
286 let mut results: Vec<(String, f64)> = docs
288 .iter()
289 .map(|(id, doc)| {
290 let similarity = cosine_similarity(&query_embedding, &doc.embedding);
291 (id.clone(), similarity)
292 })
293 .collect();
294
295 results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
297
298 let top_results: Vec<SearchResult> = results
300 .into_iter()
301 .take(limit)
302 .filter_map(|(id, score)| {
303 docs.get(&id).map(|doc| SearchResult {
304 id: doc.id.clone(),
305 score,
306 text: doc.text.clone(),
307 metadata: Some(doc.metadata.clone()),
308 })
309 })
310 .collect();
311
312 Ok(top_results)
313 }
314
315 async fn delete(&self, id: &str) -> Result<()> {
316 let mut docs = self.documents.write().await;
317 docs.remove(id);
318 Ok(())
319 }
320
321 async fn clear(&self) -> Result<()> {
322 let mut docs = self.documents.write().await;
323 docs.clear();
324 Ok(())
325 }
326
327 async fn count(&self) -> Result<usize> {
328 let docs = self.documents.read().await;
329 Ok(docs.len())
330 }
331}
332
333pub struct QdrantVectorStore {
339 qdrant_url: String,
340 collection_name: String,
341 client: Client,
342}
343
344#[derive(Debug, Serialize, Deserialize)]
345struct QdrantPoint {
346 id: String,
347 vector: Vec<f32>,
348 payload: HashMap<String, serde_json::Value>,
349}
350
351#[derive(Debug, Serialize, Deserialize)]
352struct QdrantSearchRequest {
353 vector: Vec<f32>,
354 limit: usize,
355 with_payload: bool,
356 with_vector: bool,
357}
358
359#[derive(Debug, Serialize, Deserialize)]
360struct QdrantSearchResponse {
361 result: Vec<QdrantSearchResult>,
362}
363
364#[derive(Debug, Serialize, Deserialize)]
365struct QdrantSearchResult {
366 id: String,
367 score: f64,
368 payload: Option<HashMap<String, serde_json::Value>>,
369}
370
371impl QdrantVectorStore {
372 pub fn new(qdrant_url: impl Into<String>, collection_name: impl Into<String>) -> Self {
374 Self {
375 qdrant_url: qdrant_url.into(),
376 collection_name: collection_name.into(),
377 client: Client::new(),
378 }
379 }
380}
381
382#[async_trait]
383impl VectorStore for QdrantVectorStore {
384 async fn initialize(&self, dimension: usize) -> Result<()> {
385 let collection_url = format!("{}/collections/{}", self.qdrant_url, self.collection_name);
386
387 let response = self.client.get(&collection_url).send().await;
389
390 if response.is_ok() && response.unwrap().status().is_success() {
391 return Ok(()); }
393
394 let create_payload = serde_json::json!({
396 "vectors": {
397 "size": dimension,
398 "distance": "Cosine"
399 }
400 });
401
402 let response = self
403 .client
404 .put(&collection_url)
405 .json(&create_payload)
406 .send()
407 .await
408 .map_err(|e| HeliosError::ToolError(format!("Failed to create collection: {}", e)))?;
409
410 if !response.status().is_success() {
411 let error_text = response
412 .text()
413 .await
414 .unwrap_or_else(|_| "Unknown error".to_string());
415 return Err(HeliosError::ToolError(format!(
416 "Collection creation failed: {}",
417 error_text
418 )));
419 }
420
421 Ok(())
422 }
423
424 async fn add(
425 &self,
426 id: &str,
427 embedding: Vec<f32>,
428 text: &str,
429 metadata: HashMap<String, serde_json::Value>,
430 ) -> Result<()> {
431 let mut payload = metadata;
432 payload.insert("text".to_string(), serde_json::json!(text));
433 payload.insert(
434 "timestamp".to_string(),
435 serde_json::json!(chrono::Utc::now().to_rfc3339()),
436 );
437
438 let point = QdrantPoint {
439 id: id.to_string(),
440 vector: embedding,
441 payload,
442 };
443
444 let upsert_url = format!(
445 "{}/collections/{}/points",
446 self.qdrant_url, self.collection_name
447 );
448 let upsert_payload = serde_json::json!({
449 "points": [point]
450 });
451
452 let response = self
453 .client
454 .put(&upsert_url)
455 .json(&upsert_payload)
456 .send()
457 .await
458 .map_err(|e| HeliosError::ToolError(format!("Failed to upload document: {}", e)))?;
459
460 if !response.status().is_success() {
461 let error_text = response
462 .text()
463 .await
464 .unwrap_or_else(|_| "Unknown error".to_string());
465 return Err(HeliosError::ToolError(format!(
466 "Document upload failed: {}",
467 error_text
468 )));
469 }
470
471 Ok(())
472 }
473
474 async fn search(&self, query_embedding: Vec<f32>, limit: usize) -> Result<Vec<SearchResult>> {
475 let search_url = format!(
476 "{}/collections/{}/points/search",
477 self.qdrant_url, self.collection_name
478 );
479 let search_request = QdrantSearchRequest {
480 vector: query_embedding,
481 limit,
482 with_payload: true,
483 with_vector: false,
484 };
485
486 let response = self
487 .client
488 .post(&search_url)
489 .json(&search_request)
490 .send()
491 .await
492 .map_err(|e| HeliosError::ToolError(format!("Search failed: {}", e)))?;
493
494 if !response.status().is_success() {
495 let error_text = response
496 .text()
497 .await
498 .unwrap_or_else(|_| "Unknown error".to_string());
499 return Err(HeliosError::ToolError(format!(
500 "Search request failed: {}",
501 error_text
502 )));
503 }
504
505 let search_response: QdrantSearchResponse = response.json().await.map_err(|e| {
506 HeliosError::ToolError(format!("Failed to parse search response: {}", e))
507 })?;
508
509 let results: Vec<SearchResult> = search_response
510 .result
511 .into_iter()
512 .filter_map(|r| {
513 r.payload.and_then(|p| {
514 p.get("text").and_then(|t| t.as_str()).map(|text| {
515 let mut metadata = p.clone();
516 metadata.remove("text");
517 SearchResult {
518 id: r.id,
519 score: r.score,
520 text: text.to_string(),
521 metadata: Some(metadata),
522 }
523 })
524 })
525 })
526 .collect();
527
528 Ok(results)
529 }
530
531 async fn delete(&self, id: &str) -> Result<()> {
532 let delete_url = format!(
533 "{}/collections/{}/points/delete",
534 self.qdrant_url, self.collection_name
535 );
536 let delete_payload = serde_json::json!({
537 "points": [id]
538 });
539
540 let response = self
541 .client
542 .post(&delete_url)
543 .json(&delete_payload)
544 .send()
545 .await
546 .map_err(|e| HeliosError::ToolError(format!("Delete failed: {}", e)))?;
547
548 if !response.status().is_success() {
549 let error_text = response
550 .text()
551 .await
552 .unwrap_or_else(|_| "Unknown error".to_string());
553 return Err(HeliosError::ToolError(format!(
554 "Delete request failed: {}",
555 error_text
556 )));
557 }
558
559 Ok(())
560 }
561
562 async fn clear(&self) -> Result<()> {
563 let delete_url = format!("{}/collections/{}", self.qdrant_url, self.collection_name);
564
565 let response = self
566 .client
567 .delete(&delete_url)
568 .send()
569 .await
570 .map_err(|e| HeliosError::ToolError(format!("Clear failed: {}", e)))?;
571
572 if !response.status().is_success() {
573 let error_text = response
574 .text()
575 .await
576 .unwrap_or_else(|_| "Unknown error".to_string());
577 return Err(HeliosError::ToolError(format!(
578 "Clear collection failed: {}",
579 error_text
580 )));
581 }
582
583 Ok(())
584 }
585
586 async fn count(&self) -> Result<usize> {
587 let count_url = format!("{}/collections/{}", self.qdrant_url, self.collection_name);
588
589 let response = self
590 .client
591 .get(&count_url)
592 .send()
593 .await
594 .map_err(|e| HeliosError::ToolError(format!("Count failed: {}", e)))?;
595
596 if !response.status().is_success() {
597 return Ok(0);
598 }
599
600 let info: serde_json::Value = response.json().await.map_err(|e| {
602 HeliosError::ToolError(format!("Failed to parse collection info: {}", e))
603 })?;
604
605 let count = info
606 .get("result")
607 .and_then(|r| r.get("points_count"))
608 .and_then(|c| c.as_u64())
609 .unwrap_or(0) as usize;
610
611 Ok(count)
612 }
613}
614
615pub struct RAGSystem {
621 embedding_provider: Box<dyn EmbeddingProvider>,
622 vector_store: Box<dyn VectorStore>,
623 initialized: std::sync::Arc<tokio::sync::RwLock<bool>>,
624}
625
626impl RAGSystem {
627 pub fn new(
629 embedding_provider: Box<dyn EmbeddingProvider>,
630 vector_store: Box<dyn VectorStore>,
631 ) -> Self {
632 Self {
633 embedding_provider,
634 vector_store,
635 initialized: std::sync::Arc::new(tokio::sync::RwLock::new(false)),
636 }
637 }
638
639 async fn ensure_initialized(&self) -> Result<()> {
641 let is_initialized = *self.initialized.read().await;
642 if !is_initialized {
643 let mut init_guard = self.initialized.write().await;
644 if !*init_guard {
645 let dimension = self.embedding_provider.dimension();
646 self.vector_store.initialize(dimension).await?;
647 *init_guard = true;
648 }
649 }
650 Ok(())
651 }
652
653 pub async fn add_document(
655 &self,
656 text: &str,
657 metadata: Option<HashMap<String, serde_json::Value>>,
658 ) -> Result<String> {
659 self.ensure_initialized().await?;
660
661 let id = Uuid::new_v4().to_string();
662 let embedding = self.embedding_provider.embed(text).await?;
663
664 let mut meta = metadata.unwrap_or_default();
665 meta.insert(
666 "timestamp".to_string(),
667 serde_json::json!(chrono::Utc::now().to_rfc3339()),
668 );
669
670 self.vector_store.add(&id, embedding, text, meta).await?;
671
672 Ok(id)
673 }
674
675 pub async fn search(&self, query: &str, limit: usize) -> Result<Vec<SearchResult>> {
677 self.ensure_initialized().await?;
678
679 let query_embedding = self.embedding_provider.embed(query).await?;
680 self.vector_store.search(query_embedding, limit).await
681 }
682
683 pub async fn delete_document(&self, id: &str) -> Result<()> {
685 self.vector_store.delete(id).await
686 }
687
688 pub async fn clear(&self) -> Result<()> {
690 self.vector_store.clear().await
691 }
692
693 pub async fn count(&self) -> Result<usize> {
695 self.vector_store.count().await
696 }
697}