1use std::error::Error;
2use std::fmt;
3use std::path::{Path, PathBuf};
4
5#[cfg(test)]
6use std::sync::Arc;
7#[cfg(test)]
8use std::sync::atomic::{AtomicUsize, Ordering};
9
10use backends::lancedb::{Chunk, DocumentRecord, LanceDbBackend};
11use glob::{GlobError, PatternError, glob};
12use lancedb::Error as LanceDbError;
13use serde::de::DeserializeOwned;
14use serde::{Deserialize, Serialize};
15use serde_json::{Value, json};
16use uuid::Uuid;
17
18pub mod backends;
19pub mod chunking;
20pub mod embeddings;
21mod node;
22
23pub use chunking::{ChunkingConfig, ChunkingError, chunk_text};
24pub use embeddings::{EmbeddingError, EmbeddingsConfig, EmbeddingsProvider, OrtEmbedder};
25
/// Database engine configuration used by [`RKit`] to build its storage backend.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum DbEngine {
    /// A LanceDB database.
    LanceDb {
        /// Location of the LanceDB database.
        path: PathBuf,
        /// Length of the stored embedding vectors; must be greater than zero
        /// (checked by `RKit::new` / `RKit::select_db_engine`).
        vector_dimensions: i32,
    },
}
37
38#[derive(Clone, Debug, Eq, PartialEq)]
39pub enum EmbeddingsProviderKind {
41 Ort(EmbeddingsConfig),
43}
44
45#[derive(Debug, Eq, PartialEq)]
46pub enum RKitConfigError {
47 InvalidVectorDimensions(i32),
48 InvalidChunkingConfig(ChunkingError),
49}
50
51impl fmt::Display for RKitConfigError {
52 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
53 match self {
54 Self::InvalidVectorDimensions(value) => {
55 write!(
56 f,
57 "vector_dimensions must be greater than zero, got {value}"
58 )
59 }
60 Self::InvalidChunkingConfig(error) => {
61 write!(f, "invalid chunking config: {error}")
62 }
63 }
64 }
65}
66
67impl Error for RKitConfigError {
68 fn source(&self) -> Option<&(dyn Error + 'static)> {
69 match self {
70 Self::InvalidVectorDimensions(_) => None,
71 Self::InvalidChunkingConfig(error) => Some(error),
72 }
73 }
74}
75
76#[derive(Debug)]
77pub enum RKitInitError {
78 DbEngine(LanceDbError),
79 EmbeddingsProvider(EmbeddingError),
80 EmbeddingDimensionMismatch {
81 db_vector_dimensions: i32,
82 embedding_dimensions: usize,
83 },
84}
85
86impl fmt::Display for RKitInitError {
87 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
88 match self {
89 Self::DbEngine(error) => write!(f, "failed to initialize database engine: {error}"),
90 Self::EmbeddingsProvider(error) => {
91 write!(f, "failed to initialize embeddings provider: {error}")
92 }
93 Self::EmbeddingDimensionMismatch {
94 db_vector_dimensions,
95 embedding_dimensions,
96 } => write!(
97 f,
98 "embedding dimension mismatch: database expects {db_vector_dimensions}, embedder returns {embedding_dimensions}"
99 ),
100 }
101 }
102}
103
104impl Error for RKitInitError {
105 fn source(&self) -> Option<&(dyn Error + 'static)> {
106 match self {
107 Self::DbEngine(error) => Some(error),
108 Self::EmbeddingsProvider(error) => Some(error),
109 Self::EmbeddingDimensionMismatch { .. } => None,
110 }
111 }
112}
113
// Tool names advertised by `RKit::get_tool_definitions` and dispatched by
// `RKit::invoke_tool`.
const SEMANTIC_SEARCH_TOOL_NAME: &str = "semantic_search";
const KEYWORD_SEARCH_TOOL_NAME: &str = "keyword_search";
const LIST_DOCUMENTS_TOOL_NAME: &str = "list_documents";
const GET_DOCUMENT_TOOL_NAME: &str = "get_document";

/// Number of results the search tools return when the caller omits `limit`.
const DEFAULT_TOOL_LIMIT: usize = 10;
119
120#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
121pub struct IngestDocumentResult {
123 pub document_id: String,
125 pub chunk_count: usize,
127}
128
129#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
130pub struct DocumentSummary {
132 pub document_id: String,
134}
135
136#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
137pub struct Document {
139 pub document_id: String,
141 pub content: String,
143}
144
145#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
146pub struct VectorSearchResult {
148 pub document_id: String,
150 pub text: String,
152 pub distance: f32,
154}
155
156#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
157pub struct KeywordSearchResult {
159 pub document_id: String,
161 pub text: String,
163 pub score: f32,
165}
166
167#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
168pub struct ToolDefinition {
170 pub name: String,
172 #[serde(skip_serializing_if = "Option::is_none")]
173 pub description: Option<String>,
175 #[serde(rename = "inputSchema")]
176 pub input_schema: Value,
178}
179
180#[derive(Clone, Debug, Default, Eq, PartialEq, Serialize, Deserialize)]
181pub struct ToolDescriptions {
183 #[serde(skip_serializing_if = "Option::is_none")]
184 pub semantic_search: Option<String>,
185 #[serde(skip_serializing_if = "Option::is_none")]
186 pub keyword_search: Option<String>,
187 #[serde(skip_serializing_if = "Option::is_none")]
188 pub list_documents: Option<String>,
189 #[serde(skip_serializing_if = "Option::is_none")]
190 pub get_document: Option<String>,
191}
192
193#[derive(Debug)]
194pub enum IngestDocumentError {
195 NotInitialized,
196 EmptyContent,
197 FileRead {
198 path: PathBuf,
199 source: std::io::Error,
200 },
201 InvalidGlobPattern(PatternError),
202 Glob(GlobError),
203 NoFilesMatched {
204 pattern: String,
205 },
206 Chunking(ChunkingError),
207 Embeddings(EmbeddingError),
208 DbEngine(LanceDbError),
209}
210
211impl fmt::Display for IngestDocumentError {
212 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
213 match self {
214 Self::NotInitialized => write!(f, "RKit must be initialized before ingesting"),
215 Self::EmptyContent => write!(f, "document content must not be empty"),
216 Self::FileRead { path, source } => {
217 write!(
218 f,
219 "failed to read document file {}: {source}",
220 path.display()
221 )
222 }
223 Self::InvalidGlobPattern(error) => {
224 write!(f, "invalid document file glob pattern: {error}")
225 }
226 Self::Glob(error) => write!(f, "failed to resolve document file glob: {error}"),
227 Self::NoFilesMatched { pattern } => {
228 write!(f, "document file glob matched no files: {pattern}")
229 }
230 Self::Chunking(error) => write!(f, "failed to chunk document: {error}"),
231 Self::Embeddings(error) => write!(f, "failed to generate embeddings: {error}"),
232 Self::DbEngine(error) => write!(f, "failed to insert document into database: {error}"),
233 }
234 }
235}
236
237impl Error for IngestDocumentError {
238 fn source(&self) -> Option<&(dyn Error + 'static)> {
239 match self {
240 Self::NotInitialized | Self::EmptyContent | Self::NoFilesMatched { .. } => None,
241 Self::FileRead { source, .. } => Some(source),
242 Self::InvalidGlobPattern(error) => Some(error),
243 Self::Glob(error) => Some(error),
244 Self::Chunking(error) => Some(error),
245 Self::Embeddings(error) => Some(error),
246 Self::DbEngine(error) => Some(error),
247 }
248 }
249}
250
251#[derive(Debug)]
252pub enum DocumentError {
253 NotInitialized,
254 DbEngine(LanceDbError),
255}
256
257impl fmt::Display for DocumentError {
258 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
259 match self {
260 Self::NotInitialized => write!(f, "RKit must be initialized before reading documents"),
261 Self::DbEngine(error) => write!(f, "failed to read documents from database: {error}"),
262 }
263 }
264}
265
266impl Error for DocumentError {
267 fn source(&self) -> Option<&(dyn Error + 'static)> {
268 match self {
269 Self::NotInitialized => None,
270 Self::DbEngine(error) => Some(error),
271 }
272 }
273}
274
275#[derive(Debug)]
276pub enum VectorSearchError {
277 NotInitialized,
278 EmptyQuery,
279 Embeddings(EmbeddingError),
280 DbEngine(LanceDbError),
281}
282
283impl fmt::Display for VectorSearchError {
284 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
285 match self {
286 Self::NotInitialized => write!(f, "RKit must be initialized before searching"),
287 Self::EmptyQuery => write!(f, "search query must not be empty"),
288 Self::Embeddings(error) => write!(f, "failed to generate query embedding: {error}"),
289 Self::DbEngine(error) => write!(f, "failed to search database: {error}"),
290 }
291 }
292}
293
294impl Error for VectorSearchError {
295 fn source(&self) -> Option<&(dyn Error + 'static)> {
296 match self {
297 Self::NotInitialized | Self::EmptyQuery => None,
298 Self::Embeddings(error) => Some(error),
299 Self::DbEngine(error) => Some(error),
300 }
301 }
302}
303
304#[derive(Debug)]
305pub enum KeywordSearchError {
306 NotInitialized,
307 EmptyQuery,
308 DbEngine(LanceDbError),
309}
310
311impl fmt::Display for KeywordSearchError {
312 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
313 match self {
314 Self::NotInitialized => write!(f, "RKit must be initialized before searching"),
315 Self::EmptyQuery => write!(f, "search query must not be empty"),
316 Self::DbEngine(error) => write!(f, "failed to search database: {error}"),
317 }
318 }
319}
320
321impl Error for KeywordSearchError {
322 fn source(&self) -> Option<&(dyn Error + 'static)> {
323 match self {
324 Self::NotInitialized | Self::EmptyQuery => None,
325 Self::DbEngine(error) => Some(error),
326 }
327 }
328}
329
330#[derive(Debug)]
331pub enum InvokeToolError {
332 UnknownTool(String),
333 InvalidArguments {
334 tool_name: String,
335 source: serde_json::Error,
336 },
337 Serialization(serde_json::Error),
338 SemanticSearch(VectorSearchError),
339 KeywordSearch(KeywordSearchError),
340 Document(DocumentError),
341}
342
343impl fmt::Display for InvokeToolError {
344 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
345 match self {
346 Self::UnknownTool(name) => write!(f, "unknown tool: {name}"),
347 Self::InvalidArguments { tool_name, source } => {
348 write!(f, "invalid arguments for tool {tool_name}: {source}")
349 }
350 Self::Serialization(error) => write!(f, "failed to serialize tool result: {error}"),
351 Self::SemanticSearch(error) => write!(f, "semantic search tool failed: {error}"),
352 Self::KeywordSearch(error) => write!(f, "keyword search tool failed: {error}"),
353 Self::Document(error) => write!(f, "document tool failed: {error}"),
354 }
355 }
356}
357
358impl Error for InvokeToolError {
359 fn source(&self) -> Option<&(dyn Error + 'static)> {
360 match self {
361 Self::UnknownTool(_) => None,
362 Self::InvalidArguments { source, .. } => Some(source),
363 Self::Serialization(error) => Some(error),
364 Self::SemanticSearch(error) => Some(error),
365 Self::KeywordSearch(error) => Some(error),
366 Self::Document(error) => Some(error),
367 }
368 }
369}
370
371pub struct RKit {
372 db_engine_config: DbEngine,
373 embeddings_provider_config: EmbeddingsProviderKind,
374 chunking_config: ChunkingConfig,
375 db_engine: Option<InitializedDbEngine>,
376 embeddings_provider: Option<InitializedEmbeddingsProvider>,
377}
378
379impl RKit {
380 pub fn new(
381 db_engine: DbEngine,
382 embeddings_provider: EmbeddingsProviderKind,
383 ) -> Result<Self, RKitConfigError> {
384 validate_db_engine(&db_engine)?;
385
386 Ok(Self {
387 db_engine_config: db_engine,
388 embeddings_provider_config: embeddings_provider,
389 chunking_config: ChunkingConfig::default(),
390 db_engine: None,
391 embeddings_provider: None,
392 })
393 }
394
395 pub async fn init(&mut self) -> Result<(), RKitInitError> {
396 let initialized_db_engine = init_db_engine(self.db_engine_config.clone()).await?;
397 let initialized_embeddings_provider =
398 init_embeddings_provider(self.embeddings_provider_config.clone())?;
399 validate_embedding_dimensions(&initialized_db_engine, &initialized_embeddings_provider)?;
400 ensure_db_engine_tables(&initialized_db_engine).await?;
401
402 self.db_engine = Some(initialized_db_engine);
403 self.embeddings_provider = Some(initialized_embeddings_provider);
404
405 Ok(())
406 }
407
408 pub fn is_initialized(&self) -> bool {
409 self.db_engine.is_some() && self.embeddings_provider.is_some()
410 }
411
412 pub fn db_engine_initialized(&self) -> bool {
413 self.db_engine.is_some()
414 }
415
416 pub fn embeddings_provider_initialized(&self) -> bool {
417 self.embeddings_provider.is_some()
418 }
419
420 pub fn lancedb_backend(&self) -> Option<&LanceDbBackend> {
421 match self.db_engine.as_ref() {
422 Some(InitializedDbEngine::LanceDb(backend)) => Some(backend),
423 None => None,
424 }
425 }
426
427 pub fn ort_embedder(&self) -> Option<&embeddings::OrtEmbedder> {
428 match self.embeddings_provider.as_ref() {
429 Some(InitializedEmbeddingsProvider::Ort(embedder)) => Some(embedder),
430 #[cfg(test)]
431 Some(InitializedEmbeddingsProvider::Mock(_)) => None,
432 None => None,
433 }
434 }
435
436 pub fn ort_embedder_mut(&mut self) -> Option<&mut embeddings::OrtEmbedder> {
437 match self.embeddings_provider.as_mut() {
438 Some(InitializedEmbeddingsProvider::Ort(embedder)) => Some(embedder),
439 #[cfg(test)]
440 Some(InitializedEmbeddingsProvider::Mock(_)) => None,
441 None => None,
442 }
443 }
444
445 pub fn clear_initialized_state(&mut self) {
446 self.db_engine = None;
447 self.embeddings_provider = None;
448 }
449
450 pub fn db_engine_config(&self) -> &DbEngine {
451 &self.db_engine_config
452 }
453
454 pub fn embeddings_provider_config(&self) -> &EmbeddingsProviderKind {
455 &self.embeddings_provider_config
456 }
457
458 pub fn chunking_config(&self) -> ChunkingConfig {
459 self.chunking_config
460 }
461
462 pub fn set_chunking_config(
463 &mut self,
464 chunking_config: ChunkingConfig,
465 ) -> Result<(), RKitConfigError> {
466 validate_chunking_config(chunking_config)?;
467 self.chunking_config = chunking_config;
468 Ok(())
469 }
470
471 pub fn select_db_engine(&mut self, db_engine: DbEngine) -> Result<(), RKitConfigError> {
472 validate_db_engine(&db_engine)?;
473 self.db_engine_config = db_engine;
474 self.db_engine = None;
475 Ok(())
476 }
477
478 pub fn register_embeddings_provider(
479 &mut self,
480 embeddings_provider: EmbeddingsProviderKind,
481 ) -> Result<(), RKitConfigError> {
482 self.embeddings_provider_config = embeddings_provider;
483 self.embeddings_provider = None;
484 Ok(())
485 }
486
487 pub async fn ingest_document(
488 &mut self,
489 content: String,
490 ) -> Result<IngestDocumentResult, IngestDocumentError> {
491 let mut results = self.ingest_documents(vec![content]).await?;
492 Ok(results
493 .pop()
494 .expect("single-document ingest should always yield one result"))
495 }
496
497 pub async fn ingest_document_file<P: AsRef<Path>>(
498 &mut self,
499 path: P,
500 ) -> Result<IngestDocumentResult, IngestDocumentError> {
501 let content = read_document_file(path.as_ref())?;
502 self.ingest_document(content).await
503 }
504
505 pub async fn upsert_document(
506 &mut self,
507 id: String,
508 content: String,
509 ) -> Result<IngestDocumentResult, IngestDocumentError> {
510 let document = self.prepare_document_with_id(id, content)?;
511 let db_engine = self
512 .db_engine
513 .as_ref()
514 .ok_or(IngestDocumentError::NotInitialized)?;
515 let embeddings_provider = self
516 .embeddings_provider
517 .as_ref()
518 .ok_or(IngestDocumentError::NotInitialized)?;
519 let embeddings = embeddings_provider
520 .embed_batch(&document.chunks)
521 .await
522 .map_err(IngestDocumentError::Embeddings)?;
523
524 let mut chunks = Vec::with_capacity(document.chunks.len());
525 let mut embedding_iter = embeddings.into_iter();
526 let chunk_count = document.chunks.len();
527
528 for (chunk_index, text) in document.chunks.into_iter().enumerate() {
529 let vector = embedding_iter.next().ok_or_else(|| {
530 IngestDocumentError::Embeddings(EmbeddingError::MissingOutput(
531 "fewer embeddings returned than chunks provided".to_string(),
532 ))
533 })?;
534 chunks.push(Chunk {
535 document_id: document.document_id.clone(),
536 chunk_index: chunk_index as u64,
537 text,
538 vector,
539 });
540 }
541
542 if embedding_iter.next().is_some() {
543 return Err(IngestDocumentError::Embeddings(
544 EmbeddingError::MissingOutput(
545 "more embeddings returned than chunks provided".to_string(),
546 ),
547 ));
548 }
549
550 match db_engine {
551 InitializedDbEngine::LanceDb(backend) => backend
552 .upsert_data(
553 &DocumentRecord {
554 document_id: document.document_id.clone(),
555 content: document.content,
556 },
557 &chunks,
558 )
559 .await
560 .map_err(IngestDocumentError::DbEngine)?,
561 }
562
563 Ok(IngestDocumentResult {
564 document_id: document.document_id,
565 chunk_count,
566 })
567 }
568
569 pub async fn ingest_documents(
570 &mut self,
571 documents: Vec<String>,
572 ) -> Result<Vec<IngestDocumentResult>, IngestDocumentError> {
573 if documents.is_empty() {
574 return Err(IngestDocumentError::EmptyContent);
575 }
576
577 let prepared_documents = documents
578 .into_iter()
579 .map(|content| self.prepare_document(content))
580 .collect::<Result<Vec<_>, _>>()?;
581 let flattened_chunks = prepared_documents
582 .iter()
583 .flat_map(|document| document.chunks.iter().cloned())
584 .collect::<Vec<_>>();
585
586 let db_engine = self
587 .db_engine
588 .as_ref()
589 .ok_or(IngestDocumentError::NotInitialized)?;
590 let embeddings_provider = self
591 .embeddings_provider
592 .as_ref()
593 .ok_or(IngestDocumentError::NotInitialized)?;
594 let embeddings = embeddings_provider
595 .embed_batch(&flattened_chunks)
596 .await
597 .map_err(IngestDocumentError::Embeddings)?;
598
599 let mut documents = Vec::with_capacity(prepared_documents.len());
600 let mut chunks = Vec::with_capacity(flattened_chunks.len());
601 let mut embedding_iter = embeddings.into_iter();
602 let mut results = Vec::with_capacity(prepared_documents.len());
603
604 for document in prepared_documents {
605 let chunk_count = document.chunks.len();
606
607 for (chunk_index, text) in document.chunks.into_iter().enumerate() {
608 let vector = embedding_iter.next().ok_or_else(|| {
609 IngestDocumentError::Embeddings(EmbeddingError::MissingOutput(
610 "fewer embeddings returned than chunks provided".to_string(),
611 ))
612 })?;
613 chunks.push(Chunk {
614 document_id: document.document_id.clone(),
615 chunk_index: chunk_index as u64,
616 text,
617 vector,
618 });
619 }
620
621 documents.push(DocumentRecord {
622 document_id: document.document_id.clone(),
623 content: document.content,
624 });
625 results.push(IngestDocumentResult {
626 document_id: document.document_id,
627 chunk_count,
628 });
629 }
630
631 if embedding_iter.next().is_some() {
632 return Err(IngestDocumentError::Embeddings(
633 EmbeddingError::MissingOutput(
634 "more embeddings returned than chunks provided".to_string(),
635 ),
636 ));
637 }
638
639 match db_engine {
640 InitializedDbEngine::LanceDb(backend) => backend
641 .insert_data(&documents, &chunks)
642 .await
643 .map_err(IngestDocumentError::DbEngine)?,
644 }
645
646 Ok(results)
647 }
648
649 pub async fn ingest_document_files(
650 &mut self,
651 pattern: &str,
652 ) -> Result<Vec<IngestDocumentResult>, IngestDocumentError> {
653 let paths = glob(pattern)
654 .map_err(IngestDocumentError::InvalidGlobPattern)?
655 .collect::<Result<Vec<_>, _>>()
656 .map_err(IngestDocumentError::Glob)?;
657
658 if paths.is_empty() {
659 return Err(IngestDocumentError::NoFilesMatched {
660 pattern: pattern.to_string(),
661 });
662 }
663
664 let documents = paths
665 .into_iter()
666 .map(|path| read_document_file(&path))
667 .collect::<Result<Vec<_>, _>>()?;
668
669 self.ingest_documents(documents).await
670 }
671
672 pub async fn vector_search(
673 &self,
674 query: String,
675 limit: usize,
676 ) -> Result<Vec<VectorSearchResult>, VectorSearchError> {
677 if query.trim().is_empty() {
678 return Err(VectorSearchError::EmptyQuery);
679 }
680 if !self.is_initialized() {
681 return Err(VectorSearchError::NotInitialized);
682 }
683 if limit == 0 {
684 return Ok(Vec::new());
685 }
686
687 let embeddings_provider = self
688 .embeddings_provider
689 .as_ref()
690 .expect("initialized embeddings provider");
691 let mut embeddings = embeddings_provider
692 .embed_batch(&[query])
693 .await
694 .map_err(VectorSearchError::Embeddings)?;
695 let query_vector = embeddings.pop().ok_or_else(|| {
696 VectorSearchError::Embeddings(EmbeddingError::MissingOutput(
697 "no query embedding returned".to_string(),
698 ))
699 })?;
700
701 if embeddings.pop().is_some() {
702 return Err(VectorSearchError::Embeddings(
703 EmbeddingError::MissingOutput("more than one query embedding returned".to_string()),
704 ));
705 }
706
707 let db_engine = self.db_engine.as_ref().expect("initialized db engine");
708 match db_engine {
709 InitializedDbEngine::LanceDb(backend) => backend
710 .vector_search(query_vector, limit)
711 .await
712 .map(|results| {
713 results
714 .into_iter()
715 .map(|result| VectorSearchResult {
716 document_id: result.document_id,
717 text: result.text,
718 distance: result.distance,
719 })
720 .collect()
721 })
722 .map_err(VectorSearchError::DbEngine),
723 }
724 }
725
726 pub async fn keyword_search(
727 &self,
728 query: String,
729 limit: usize,
730 ) -> Result<Vec<KeywordSearchResult>, KeywordSearchError> {
731 if query.trim().is_empty() {
732 return Err(KeywordSearchError::EmptyQuery);
733 }
734 if !self.is_initialized() {
735 return Err(KeywordSearchError::NotInitialized);
736 }
737 if limit == 0 {
738 return Ok(Vec::new());
739 }
740
741 let db_engine = self.db_engine.as_ref().expect("initialized db engine");
742 match db_engine {
743 InitializedDbEngine::LanceDb(backend) => backend
744 .keyword_search(query, limit)
745 .await
746 .map(|results| {
747 results
748 .into_iter()
749 .map(|result| KeywordSearchResult {
750 document_id: result.document_id,
751 text: result.text,
752 score: result.score,
753 })
754 .collect()
755 })
756 .map_err(KeywordSearchError::DbEngine),
757 }
758 }
759
760 pub async fn list_documents(&self) -> Result<Vec<DocumentSummary>, DocumentError> {
761 if !self.is_initialized() {
762 return Err(DocumentError::NotInitialized);
763 }
764 let db_engine = self.db_engine.as_ref().expect("initialized db engine");
765
766 match db_engine {
767 InitializedDbEngine::LanceDb(backend) => backend
768 .list_documents()
769 .await
770 .map(|documents| {
771 documents
772 .into_iter()
773 .map(|document| DocumentSummary {
774 document_id: document.document_id,
775 })
776 .collect()
777 })
778 .map_err(DocumentError::DbEngine),
779 }
780 }
781
782 pub async fn get_document(&self, id: String) -> Result<Option<Document>, DocumentError> {
783 if !self.is_initialized() {
784 return Err(DocumentError::NotInitialized);
785 }
786 let db_engine = self.db_engine.as_ref().expect("initialized db engine");
787
788 match db_engine {
789 InitializedDbEngine::LanceDb(backend) => backend
790 .get_document(&id)
791 .await
792 .map(|document| {
793 document.map(|document| Document {
794 document_id: document.document_id,
795 content: document.content,
796 })
797 })
798 .map_err(DocumentError::DbEngine),
799 }
800 }
801
802 pub async fn delete_document(&self, id: String) -> Result<(), DocumentError> {
803 if !self.is_initialized() {
804 return Err(DocumentError::NotInitialized);
805 }
806 let db_engine = self.db_engine.as_ref().expect("initialized db engine");
807
808 match db_engine {
809 InitializedDbEngine::LanceDb(backend) => backend
810 .delete_document(&id)
811 .await
812 .map_err(DocumentError::DbEngine),
813 }
814 }
815
816 pub fn get_tool_definitions(
817 &self,
818 descriptions: Option<ToolDescriptions>,
819 ) -> Vec<ToolDefinition> {
820 let descriptions = descriptions.unwrap_or_default();
821
822 vec![
823 ToolDefinition {
824 name: SEMANTIC_SEARCH_TOOL_NAME.to_string(),
825 description: Some(descriptions.semantic_search.unwrap_or_else(|| {
826 "Search ingested documents by semantic similarity to a natural language query."
827 .to_string()
828 })),
829 input_schema: search_tool_input_schema(),
830 },
831 ToolDefinition {
832 name: KEYWORD_SEARCH_TOOL_NAME.to_string(),
833 description: Some(descriptions.keyword_search.unwrap_or_else(|| {
834 "Search ingested documents by exact keyword and full-text relevance."
835 .to_string()
836 })),
837 input_schema: search_tool_input_schema(),
838 },
839 ToolDefinition {
840 name: LIST_DOCUMENTS_TOOL_NAME.to_string(),
841 description: Some(descriptions.list_documents.unwrap_or_else(|| {
842 "List document identifiers currently stored in the retrieval index.".to_string()
843 })),
844 input_schema: empty_tool_input_schema(),
845 },
846 ToolDefinition {
847 name: GET_DOCUMENT_TOOL_NAME.to_string(),
848 description: Some(descriptions.get_document.unwrap_or_else(|| {
849 "Fetch the full stored content for a document by document identifier."
850 .to_string()
851 })),
852 input_schema: get_document_tool_input_schema(),
853 },
854 ]
855 }
856
857 pub async fn invoke_tool(
858 &self,
859 name: &str,
860 arguments: Value,
861 ) -> Result<Value, InvokeToolError> {
862 match name {
863 SEMANTIC_SEARCH_TOOL_NAME => {
864 let arguments: SearchToolArguments = parse_tool_arguments(name, arguments)?;
865 let results = self
866 .vector_search(
867 arguments.query,
868 arguments.limit.unwrap_or(DEFAULT_TOOL_LIMIT),
869 )
870 .await
871 .map_err(InvokeToolError::SemanticSearch)?;
872 serialize_tool_result(json!({ "results": results }))
873 }
874 KEYWORD_SEARCH_TOOL_NAME => {
875 let arguments: SearchToolArguments = parse_tool_arguments(name, arguments)?;
876 let results = self
877 .keyword_search(
878 arguments.query,
879 arguments.limit.unwrap_or(DEFAULT_TOOL_LIMIT),
880 )
881 .await
882 .map_err(InvokeToolError::KeywordSearch)?;
883 serialize_tool_result(json!({ "results": results }))
884 }
885 LIST_DOCUMENTS_TOOL_NAME => {
886 let _: EmptyToolArguments = parse_tool_arguments(name, arguments)?;
887 let documents = self
888 .list_documents()
889 .await
890 .map_err(InvokeToolError::Document)?;
891 serialize_tool_result(json!({ "documents": documents }))
892 }
893 GET_DOCUMENT_TOOL_NAME => {
894 let arguments: GetDocumentToolArguments = parse_tool_arguments(name, arguments)?;
895 let document = self
896 .get_document(arguments.document_id)
897 .await
898 .map_err(InvokeToolError::Document)?;
899 serialize_tool_result(json!({ "document": document }))
900 }
901 other => Err(InvokeToolError::UnknownTool(other.to_string())),
902 }
903 }
904
905 fn prepare_document(&self, content: String) -> Result<PreparedDocument, IngestDocumentError> {
906 self.prepare_document_with_id(Uuid::new_v4().to_string(), content)
907 }
908
909 fn prepare_document_with_id(
910 &self,
911 document_id: String,
912 content: String,
913 ) -> Result<PreparedDocument, IngestDocumentError> {
914 if content.trim().is_empty() {
915 return Err(IngestDocumentError::EmptyContent);
916 }
917
918 let chunks = self
919 .chunk_document_content(&content)
920 .map_err(IngestDocumentError::Chunking)?;
921 if chunks.is_empty() {
922 return Err(IngestDocumentError::EmptyContent);
923 }
924
925 Ok(PreparedDocument {
926 document_id,
927 content,
928 chunks,
929 })
930 }
931
932 fn chunk_document_content(&self, content: &str) -> Result<Vec<String>, ChunkingError> {
933 match self.embeddings_provider.as_ref() {
934 Some(InitializedEmbeddingsProvider::Ort(embedder)) => embedder
935 .chunk_text(content, self.chunking_config.overlap_size)
936 .map_err(|_| ChunkingError::EmbeddingTokenizer),
937 #[cfg(test)]
938 Some(InitializedEmbeddingsProvider::Mock(_)) | None => {
939 chunk_text(content, self.chunking_config)
940 }
941 #[cfg(not(test))]
942 None => chunk_text(content, self.chunking_config),
943 }
944 }
945}
946
947enum InitializedDbEngine {
948 LanceDb(LanceDbBackend),
949}
950
/// Deterministic test-only embedder that counts its `embed_batch` calls.
#[cfg(test)]
struct MockEmbedder {
    /// Length of every produced vector.
    dimensions: usize,
    /// Number of `embed_batch` calls so far.
    calls: Arc<AtomicUsize>,
}

#[cfg(test)]
impl MockEmbedder {
    /// Creates a mock that produces `dimensions`-long vectors, with zero calls recorded.
    fn new(dimensions: usize) -> Self {
        let calls = Arc::new(AtomicUsize::new(0));
        Self { dimensions, calls }
    }

    /// Returns how many times `embed_batch` has been called.
    fn call_count(&self) -> usize {
        self.calls.load(Ordering::SeqCst)
    }

    /// Returns one vector per text. Every element of the vector for the text at
    /// position `i` equals `n + i`, where `n` is this call's 1-based call number.
    /// Never fails.
    async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>, EmbeddingError> {
        let first_value = self.calls.fetch_add(1, Ordering::SeqCst) + 1;
        let vectors = (first_value..first_value + texts.len())
            .map(|value| vec![value as f32; self.dimensions])
            .collect();
        Ok(vectors)
    }
}
979
980enum InitializedEmbeddingsProvider {
981 Ort(embeddings::OrtEmbedder),
982 #[cfg(test)]
983 Mock(MockEmbedder),
984}
985
986impl InitializedEmbeddingsProvider {
987 async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>, EmbeddingError> {
988 match self {
989 Self::Ort(embedder) => embedder.embed_batch_shared(texts).await,
990 #[cfg(test)]
991 Self::Mock(embedder) => embedder.embed_batch(texts).await,
992 }
993 }
994
995 fn expected_embedding_size(&self) -> Option<usize> {
996 match self {
997 Self::Ort(embedder) => embedder.expected_embedding_size(),
998 #[cfg(test)]
999 Self::Mock(embedder) => Some(embedder.dimensions),
1000 }
1001 }
1002}
1003
/// A validated document split into chunk texts, ready to be embedded.
struct PreparedDocument {
    /// Identifier the document will be stored under.
    document_id: String,
    /// Full original text of the document.
    content: String,
    /// Chunk texts to embed (never empty when built by `prepare_document_with_id`).
    chunks: Vec<String>,
}
1009
1010fn read_document_file(path: &Path) -> Result<String, IngestDocumentError> {
1011 std::fs::read_to_string(path).map_err(|source| IngestDocumentError::FileRead {
1012 path: path.to_path_buf(),
1013 source,
1014 })
1015}
1016
1017async fn init_db_engine(db_engine: DbEngine) -> Result<InitializedDbEngine, RKitInitError> {
1018 match db_engine {
1019 DbEngine::LanceDb {
1020 path,
1021 vector_dimensions,
1022 } => LanceDbBackend::new(path, vector_dimensions)
1023 .await
1024 .map(InitializedDbEngine::LanceDb)
1025 .map_err(RKitInitError::DbEngine),
1026 }
1027}
1028
1029fn init_embeddings_provider(
1030 embeddings_provider: EmbeddingsProviderKind,
1031) -> Result<InitializedEmbeddingsProvider, RKitInitError> {
1032 match embeddings_provider {
1033 EmbeddingsProviderKind::Ort(config) => OrtEmbedder::new(config)
1034 .map(InitializedEmbeddingsProvider::Ort)
1035 .map_err(RKitInitError::EmbeddingsProvider),
1036 }
1037}
1038
1039fn validate_db_engine(db_engine: &DbEngine) -> Result<(), RKitConfigError> {
1040 match db_engine {
1041 DbEngine::LanceDb {
1042 vector_dimensions, ..
1043 } if *vector_dimensions <= 0 => {
1044 Err(RKitConfigError::InvalidVectorDimensions(*vector_dimensions))
1045 }
1046 DbEngine::LanceDb { .. } => Ok(()),
1047 }
1048}
1049
1050fn validate_chunking_config(chunking_config: ChunkingConfig) -> Result<(), RKitConfigError> {
1051 chunk_text("validation", chunking_config)
1052 .map(|_| ())
1053 .map_err(RKitConfigError::InvalidChunkingConfig)
1054}
1055
1056async fn ensure_db_engine_tables(db_engine: &InitializedDbEngine) -> Result<(), RKitInitError> {
1057 match db_engine {
1058 InitializedDbEngine::LanceDb(backend) => backend
1059 .create_tables()
1060 .await
1061 .map(|_| ())
1062 .map_err(RKitInitError::DbEngine),
1063 }
1064}
1065
1066fn validate_embedding_dimensions(
1067 db_engine: &InitializedDbEngine,
1068 embeddings_provider: &InitializedEmbeddingsProvider,
1069) -> Result<(), RKitInitError> {
1070 let Some(embedding_dimensions) = embeddings_provider.expected_embedding_size() else {
1071 return Ok(());
1072 };
1073
1074 match db_engine {
1075 InitializedDbEngine::LanceDb(backend)
1076 if backend.vector_dimensions() as usize != embedding_dimensions =>
1077 {
1078 Err(RKitInitError::EmbeddingDimensionMismatch {
1079 db_vector_dimensions: backend.vector_dimensions(),
1080 embedding_dimensions,
1081 })
1082 }
1083 InitializedDbEngine::LanceDb(_) => Ok(()),
1084 }
1085}
1086
1087#[derive(Debug, Deserialize)]
1088#[serde(deny_unknown_fields)]
1089struct SearchToolArguments {
1090 query: String,
1091 limit: Option<usize>,
1092}
1093
1094#[derive(Debug, Deserialize)]
1095#[serde(deny_unknown_fields)]
1096struct EmptyToolArguments {}
1097
1098#[derive(Debug, Deserialize)]
1099#[serde(deny_unknown_fields)]
1100struct GetDocumentToolArguments {
1101 document_id: String,
1102}
1103
1104fn parse_tool_arguments<T: DeserializeOwned>(
1105 tool_name: &str,
1106 arguments: Value,
1107) -> Result<T, InvokeToolError> {
1108 serde_json::from_value(arguments).map_err(|source| InvokeToolError::InvalidArguments {
1109 tool_name: tool_name.to_string(),
1110 source,
1111 })
1112}
1113
1114fn serialize_tool_result<T: Serialize>(value: T) -> Result<Value, InvokeToolError> {
1115 serde_json::to_value(value).map_err(InvokeToolError::Serialization)
1116}
1117
1118fn search_tool_input_schema() -> Value {
1119 json!({
1120 "type": "object",
1121 "properties": {
1122 "query": {
1123 "type": "string",
1124 "description": "The search query."
1125 },
1126 "limit": {
1127 "type": "integer",
1128 "minimum": 0,
1129 "description": "Maximum number of matching chunks to return. Defaults to 10."
1130 }
1131 },
1132 "required": ["query"],
1133 "additionalProperties": false
1134 })
1135}
1136
1137fn empty_tool_input_schema() -> Value {
1138 json!({
1139 "type": "object",
1140 "properties": {},
1141 "additionalProperties": false
1142 })
1143}
1144
1145fn get_document_tool_input_schema() -> Value {
1146 json!({
1147 "type": "object",
1148 "properties": {
1149 "document_id": {
1150 "type": "string",
1151 "description": "The document identifier returned by ingestion or list_documents."
1152 }
1153 },
1154 "required": ["document_id"],
1155 "additionalProperties": false
1156 })
1157}
1158
1159#[cfg(test)]
1160mod tests {
1161 use super::{
1162 ChunkingConfig, DbEngine, Document, DocumentError, DocumentSummary, EmbeddingError,
1163 EmbeddingsConfig, EmbeddingsProviderKind, IngestDocumentError,
1164 InitializedEmbeddingsProvider, InvokeToolError, KeywordSearchError, KeywordSearchResult,
1165 MockEmbedder, RKit, RKitConfigError, RKitInitError, ToolDescriptions, VectorSearchError,
1166 VectorSearchResult,
1167 };
1168 use arrow_array::{Array, FixedSizeListArray, Float32Array, StringArray};
1169 use futures::TryStreamExt;
1170 use lancedb::query::ExecutableQuery;
1171 use serde_json::json;
1172 use std::fs;
1173 use std::path::PathBuf;
1174 use tempfile::{TempDir, tempdir};
1175
1176 fn demo_engine(path: &str, vector_dimensions: i32) -> DbEngine {
1177 DbEngine::LanceDb {
1178 path: PathBuf::from(path),
1179 vector_dimensions,
1180 }
1181 }
1182
1183 fn missing_local_ort_provider() -> EmbeddingsProviderKind {
1184 EmbeddingsProviderKind::Ort(EmbeddingsConfig {
1185 local_model_path: Some(PathBuf::from("/tmp/retrieval-kit-missing-model.onnx")),
1186 local_tokenizer_path: Some(PathBuf::from("/tmp/retrieval-kit-missing-tokenizer.json")),
1187 local_pooling_config_path: Some(PathBuf::from(
1188 "/tmp/retrieval-kit-missing-pooling-config.json",
1189 )),
1190 local_transformer_config_path: Some(PathBuf::from(
1191 "/tmp/retrieval-kit-missing-transformer-config.json",
1192 )),
1193 ..EmbeddingsConfig::default()
1194 })
1195 }
1196
1197 async fn table_document_ids(rkit: &RKit) -> Vec<String> {
1198 let backend = rkit.lancedb_backend().unwrap();
1199 let table = backend
1200 .connection()
1201 .open_table("chunks")
1202 .execute()
1203 .await
1204 .unwrap();
1205 let rows = table.query().execute().await.unwrap();
1206 let batches = rows.try_collect::<Vec<_>>().await.unwrap();
1207
1208 batches
1209 .iter()
1210 .flat_map(|batch| {
1211 batch
1212 .column_by_name("document_id")
1213 .unwrap()
1214 .as_any()
1215 .downcast_ref::<StringArray>()
1216 .unwrap()
1217 .iter()
1218 .flatten()
1219 .map(str::to_owned)
1220 .collect::<Vec<_>>()
1221 })
1222 .collect()
1223 }
1224
1225 async fn table_texts(rkit: &RKit) -> Vec<String> {
1226 let backend = rkit.lancedb_backend().unwrap();
1227 let table = backend
1228 .connection()
1229 .open_table("chunks")
1230 .execute()
1231 .await
1232 .unwrap();
1233 let rows = table.query().execute().await.unwrap();
1234 let batches = rows.try_collect::<Vec<_>>().await.unwrap();
1235
1236 batches
1237 .iter()
1238 .flat_map(|batch| {
1239 batch
1240 .column_by_name("text")
1241 .unwrap()
1242 .as_any()
1243 .downcast_ref::<StringArray>()
1244 .unwrap()
1245 .iter()
1246 .flatten()
1247 .map(str::to_owned)
1248 .collect::<Vec<_>>()
1249 })
1250 .collect()
1251 }
1252
1253 async fn table_vectors(rkit: &RKit) -> Vec<Vec<f32>> {
1254 let backend = rkit.lancedb_backend().unwrap();
1255 let table = backend
1256 .connection()
1257 .open_table("chunks")
1258 .execute()
1259 .await
1260 .unwrap();
1261 let rows = table.query().execute().await.unwrap();
1262 let batches = rows.try_collect::<Vec<_>>().await.unwrap();
1263
1264 batches
1265 .iter()
1266 .flat_map(|batch| {
1267 let vectors = batch
1268 .column_by_name("vector")
1269 .unwrap()
1270 .as_any()
1271 .downcast_ref::<FixedSizeListArray>()
1272 .unwrap();
1273
1274 (0..vectors.len())
1275 .map(|index| {
1276 vectors
1277 .value(index)
1278 .as_any()
1279 .downcast_ref::<Float32Array>()
1280 .unwrap()
1281 .values()
1282 .to_vec()
1283 })
1284 .collect::<Vec<_>>()
1285 })
1286 .collect()
1287 }
1288
1289 async fn table_stored_documents(rkit: &RKit) -> Vec<Document> {
1290 let backend = rkit.lancedb_backend().unwrap();
1291 let mut documents = backend
1292 .list_documents()
1293 .await
1294 .unwrap()
1295 .into_iter()
1296 .map(|document| Document {
1297 document_id: document.document_id,
1298 content: document.content,
1299 })
1300 .collect::<Vec<_>>();
1301 documents.sort_by(|left, right| left.document_id.cmp(&right.document_id));
1302 documents
1303 }
1304
1305 async fn initialized_test_rkit(vector_dimensions: i32) -> (TempDir, RKit) {
1306 let temp_dir = tempdir().unwrap();
1307 let mut rkit = RKit::new(
1308 demo_engine(temp_dir.path().to_str().unwrap(), vector_dimensions),
1309 EmbeddingsProviderKind::Ort(EmbeddingsConfig::default()),
1310 )
1311 .unwrap();
1312 rkit.db_engine = Some(
1313 super::init_db_engine(demo_engine(
1314 temp_dir.path().to_str().unwrap(),
1315 vector_dimensions,
1316 ))
1317 .await
1318 .unwrap(),
1319 );
1320 super::ensure_db_engine_tables(rkit.db_engine.as_ref().unwrap())
1321 .await
1322 .unwrap();
1323 rkit.embeddings_provider = Some(InitializedEmbeddingsProvider::Mock(MockEmbedder::new(
1324 vector_dimensions as usize,
1325 )));
1326 (temp_dir, rkit)
1327 }
1328
1329 #[test]
1330 fn get_tool_definitions_returns_default_mcp_compatible_definitions() {
1331 let rkit = RKit::new(
1332 demo_engine("/tmp/rkit-a", 384),
1333 EmbeddingsProviderKind::Ort(EmbeddingsConfig::default()),
1334 )
1335 .unwrap();
1336
1337 let definitions = rkit.get_tool_definitions(None);
1338
1339 assert_eq!(definitions.len(), 4);
1340 assert_eq!(definitions[0].name, "semantic_search");
1341 assert_eq!(definitions[1].name, "keyword_search");
1342 assert_eq!(definitions[2].name, "list_documents");
1343 assert_eq!(definitions[3].name, "get_document");
1344 assert_eq!(
1345 definitions[0].description.as_deref(),
1346 Some("Search ingested documents by semantic similarity to a natural language query.")
1347 );
1348 assert_eq!(definitions[0].input_schema["required"], json!(["query"]));
1349 assert_eq!(
1350 definitions[0].input_schema["additionalProperties"],
1351 json!(false)
1352 );
1353 assert_eq!(definitions[2].input_schema["properties"], json!({}));
1354 assert_eq!(
1355 definitions[3].input_schema["required"],
1356 json!(["document_id"])
1357 );
1358 }
1359
1360 #[test]
1361 fn get_tool_definitions_applies_description_overrides_selectively() {
1362 let rkit = RKit::new(
1363 demo_engine("/tmp/rkit-a", 384),
1364 EmbeddingsProviderKind::Ort(EmbeddingsConfig::default()),
1365 )
1366 .unwrap();
1367
1368 let definitions = rkit.get_tool_definitions(Some(ToolDescriptions {
1369 semantic_search: Some("Custom semantic search".to_string()),
1370 get_document: Some("Custom document fetch".to_string()),
1371 ..ToolDescriptions::default()
1372 }));
1373
1374 assert_eq!(
1375 definitions[0].description.as_deref(),
1376 Some("Custom semantic search")
1377 );
1378 assert_eq!(
1379 definitions[1].description.as_deref(),
1380 Some("Search ingested documents by exact keyword and full-text relevance.")
1381 );
1382 assert_eq!(
1383 definitions[3].description.as_deref(),
1384 Some("Custom document fetch")
1385 );
1386 }
1387
1388 #[tokio::test]
1389 async fn invoke_tool_dispatches_document_tools() {
1390 let (_temp_dir, mut rkit) = initialized_test_rkit(3).await;
1391 rkit.upsert_document("manual-doc".to_string(), "Document content.".to_string())
1392 .await
1393 .unwrap();
1394
1395 let list_result = rkit.invoke_tool("list_documents", json!({})).await.unwrap();
1396 let get_result = rkit
1397 .invoke_tool("get_document", json!({ "document_id": "manual-doc" }))
1398 .await
1399 .unwrap();
1400 let missing_result = rkit
1401 .invoke_tool("get_document", json!({ "document_id": "missing-doc" }))
1402 .await
1403 .unwrap();
1404
1405 assert_eq!(
1406 list_result,
1407 json!({ "documents": [{ "document_id": "manual-doc" }] })
1408 );
1409 assert_eq!(
1410 get_result,
1411 json!({
1412 "document": {
1413 "document_id": "manual-doc",
1414 "content": "Document content."
1415 }
1416 })
1417 );
1418 assert_eq!(missing_result, json!({ "document": null }));
1419 }
1420
1421 #[tokio::test]
1422 async fn invoke_tool_dispatches_search_tools() {
1423 let (_temp_dir, mut rkit) = initialized_test_rkit(3).await;
1424 rkit.upsert_document(
1425 "first-doc".to_string(),
1426 "First document content.".to_string(),
1427 )
1428 .await
1429 .unwrap();
1430 rkit.upsert_document(
1431 "keyword-doc".to_string(),
1432 "Rust search database.".to_string(),
1433 )
1434 .await
1435 .unwrap();
1436
1437 let semantic_result = rkit
1438 .invoke_tool(
1439 "semantic_search",
1440 json!({ "query": "content like the second document", "limit": 1 }),
1441 )
1442 .await
1443 .unwrap();
1444 let keyword_result = rkit
1445 .invoke_tool("keyword_search", json!({ "query": "rust", "limit": 1 }))
1446 .await
1447 .unwrap();
1448
1449 let semantic_results: Vec<VectorSearchResult> =
1450 serde_json::from_value(semantic_result["results"].clone()).unwrap();
1451 let keyword_results: Vec<KeywordSearchResult> =
1452 serde_json::from_value(keyword_result["results"].clone()).unwrap();
1453
1454 assert_eq!(semantic_results.len(), 1);
1455 assert_eq!(semantic_results[0].document_id, "keyword-doc");
1456 assert_eq!(keyword_results.len(), 1);
1457 assert_eq!(keyword_results[0].document_id, "keyword-doc");
1458 assert_eq!(keyword_results[0].text, "Rust search database.");
1459 }
1460
1461 #[tokio::test]
1462 async fn invoke_tool_uses_default_search_limit() {
1463 let (_temp_dir, mut rkit) = initialized_test_rkit(3).await;
1464 for index in 0..11 {
1465 rkit.upsert_document(format!("doc-{index:02}"), format!("Document {index}."))
1466 .await
1467 .unwrap();
1468 }
1469
1470 let result = rkit
1471 .invoke_tool("semantic_search", json!({ "query": "document" }))
1472 .await
1473 .unwrap();
1474
1475 assert_eq!(result["results"].as_array().unwrap().len(), 10);
1476 }
1477
1478 #[tokio::test]
1479 async fn invoke_tool_reports_bad_calls_as_errors() {
1480 let (_temp_dir, rkit) = initialized_test_rkit(3).await;
1481
1482 let unknown = rkit
1483 .invoke_tool("missing_tool", json!({}))
1484 .await
1485 .unwrap_err();
1486 let invalid = rkit
1487 .invoke_tool("semantic_search", json!({ "limit": 1 }))
1488 .await
1489 .unwrap_err();
1490 let extra_argument = rkit
1491 .invoke_tool("list_documents", json!({ "unexpected": true }))
1492 .await
1493 .unwrap_err();
1494
1495 assert!(matches!(unknown, InvokeToolError::UnknownTool(name) if name == "missing_tool"));
1496 assert!(matches!(
1497 invalid,
1498 InvokeToolError::InvalidArguments { tool_name, .. } if tool_name == "semantic_search"
1499 ));
1500 assert!(matches!(
1501 extra_argument,
1502 InvokeToolError::InvalidArguments { tool_name, .. } if tool_name == "list_documents"
1503 ));
1504 }
1505
1506 #[test]
1507 fn new_stores_both_configs() {
1508 let db_engine = demo_engine("/tmp/rkit-a", 384);
1509 let embeddings_provider = EmbeddingsProviderKind::Ort(EmbeddingsConfig::default());
1510
1511 let rkit = RKit::new(db_engine.clone(), embeddings_provider.clone()).unwrap();
1512
1513 assert_eq!(rkit.db_engine_config(), &db_engine);
1514 assert_eq!(rkit.embeddings_provider_config(), &embeddings_provider);
1515 assert_eq!(rkit.chunking_config(), ChunkingConfig::default());
1516 }
1517
1518 #[test]
1519 fn select_db_engine_accepts_valid_lancedb_config() {
1520 let mut rkit = RKit::new(
1521 demo_engine("/tmp/rkit-a", 384),
1522 EmbeddingsProviderKind::Ort(EmbeddingsConfig::default()),
1523 )
1524 .unwrap();
1525 let next_engine = demo_engine("/tmp/rkit-b", 768);
1526
1527 rkit.select_db_engine(next_engine.clone()).unwrap();
1528
1529 assert_eq!(rkit.db_engine_config(), &next_engine);
1530 }
1531
1532 #[test]
1533 fn select_db_engine_rejects_non_positive_vector_dimensions() {
1534 let mut rkit = RKit::new(
1535 demo_engine("/tmp/rkit-a", 384),
1536 EmbeddingsProviderKind::Ort(EmbeddingsConfig::default()),
1537 )
1538 .unwrap();
1539
1540 let error = rkit
1541 .select_db_engine(demo_engine("/tmp/rkit-b", 0))
1542 .unwrap_err();
1543
1544 assert_eq!(error, RKitConfigError::InvalidVectorDimensions(0));
1545 }
1546
1547 #[test]
1548 fn register_embeddings_provider_accepts_default_ort_config() {
1549 let mut rkit = RKit::new(
1550 demo_engine("/tmp/rkit-a", 384),
1551 EmbeddingsProviderKind::Ort(EmbeddingsConfig::default()),
1552 )
1553 .unwrap();
1554 let provider = EmbeddingsProviderKind::Ort(EmbeddingsConfig::default());
1555
1556 rkit.register_embeddings_provider(provider.clone()).unwrap();
1557
1558 assert_eq!(rkit.embeddings_provider_config(), &provider);
1559 }
1560
1561 #[test]
1562 fn reselecting_provider_replaces_prior_config() {
1563 let mut rkit = RKit::new(
1564 demo_engine("/tmp/rkit-a", 384),
1565 EmbeddingsProviderKind::Ort(EmbeddingsConfig::default()),
1566 )
1567 .unwrap();
1568 let provider = EmbeddingsProviderKind::Ort(EmbeddingsConfig {
1569 normalize: false,
1570 max_length: 64,
1571 ..EmbeddingsConfig::default()
1572 });
1573
1574 rkit.register_embeddings_provider(provider.clone()).unwrap();
1575
1576 assert_eq!(rkit.embeddings_provider_config(), &provider);
1577 }
1578
1579 #[test]
1580 fn set_chunking_config_validates_values() {
1581 let mut rkit = RKit::new(
1582 demo_engine("/tmp/rkit-a", 384),
1583 EmbeddingsProviderKind::Ort(EmbeddingsConfig::default()),
1584 )
1585 .unwrap();
1586 let chunking_config = ChunkingConfig {
1587 chunk_size: 128,
1588 overlap_size: 16,
1589 };
1590
1591 rkit.set_chunking_config(chunking_config).unwrap();
1592
1593 assert_eq!(rkit.chunking_config(), chunking_config);
1594 }
1595
1596 #[test]
1597 fn set_chunking_config_rejects_invalid_values() {
1598 let mut rkit = RKit::new(
1599 demo_engine("/tmp/rkit-a", 384),
1600 EmbeddingsProviderKind::Ort(EmbeddingsConfig::default()),
1601 )
1602 .unwrap();
1603
1604 let error = rkit
1605 .set_chunking_config(ChunkingConfig {
1606 chunk_size: 32,
1607 overlap_size: 32,
1608 })
1609 .unwrap_err();
1610
1611 assert!(matches!(error, RKitConfigError::InvalidChunkingConfig(_)));
1612 }
1613
1614 #[tokio::test]
1615 async fn init_db_engine_initializes_lancedb_backend() {
1616 let temp_dir = tempdir().unwrap();
1617 let initialized_db_engine =
1618 super::init_db_engine(demo_engine(temp_dir.path().to_str().unwrap(), 384))
1619 .await
1620 .unwrap();
1621
1622 super::ensure_db_engine_tables(&initialized_db_engine)
1623 .await
1624 .unwrap();
1625
1626 match initialized_db_engine {
1627 super::InitializedDbEngine::LanceDb(backend) => {
1628 assert_eq!(backend.vector_dimensions(), 384);
1629 let table = backend
1630 .connection()
1631 .open_table("chunks")
1632 .execute()
1633 .await
1634 .unwrap();
1635 let documents_table = backend
1636 .connection()
1637 .open_table("documents")
1638 .execute()
1639 .await
1640 .unwrap();
1641 assert_eq!(documents_table.count_rows(None).await.unwrap(), 0);
1642 assert_eq!(table.count_rows(None).await.unwrap(), 0);
1643 }
1644 }
1645 }
1646
1647 #[tokio::test]
1648 async fn init_is_atomic_when_embeddings_provider_initialization_fails() {
1649 let temp_dir = tempdir().unwrap();
1650 let mut rkit = RKit::new(
1651 demo_engine(temp_dir.path().to_str().unwrap(), 384),
1652 missing_local_ort_provider(),
1653 )
1654 .unwrap();
1655
1656 let error = rkit.init().await.unwrap_err();
1657
1658 assert!(matches!(
1659 error,
1660 RKitInitError::EmbeddingsProvider(EmbeddingError::MissingAsset { .. })
1661 ));
1662 assert!(rkit.lancedb_backend().is_none());
1663 assert!(rkit.ort_embedder().is_none());
1664 }
1665
1666 #[tokio::test]
1667 async fn init_dimension_validation_rejects_mismatched_embedder() {
1668 let temp_dir = tempdir().unwrap();
1669 let db_engine = super::init_db_engine(demo_engine(temp_dir.path().to_str().unwrap(), 5))
1670 .await
1671 .unwrap();
1672 let embeddings_provider = InitializedEmbeddingsProvider::Mock(MockEmbedder::new(3));
1673
1674 let error =
1675 super::validate_embedding_dimensions(&db_engine, &embeddings_provider).unwrap_err();
1676
1677 assert!(matches!(
1678 error,
1679 RKitInitError::EmbeddingDimensionMismatch {
1680 db_vector_dimensions: 5,
1681 embedding_dimensions: 3
1682 }
1683 ));
1684 }
1685
1686 #[tokio::test]
1687 async fn selecting_new_db_engine_clears_initialized_backend() {
1688 let temp_dir = tempdir().unwrap();
1689 let mut rkit = RKit::new(
1690 demo_engine(temp_dir.path().to_str().unwrap(), 384),
1691 EmbeddingsProviderKind::Ort(EmbeddingsConfig::default()),
1692 )
1693 .unwrap();
1694 rkit.db_engine = Some(
1695 super::init_db_engine(demo_engine(temp_dir.path().to_str().unwrap(), 384))
1696 .await
1697 .unwrap(),
1698 );
1699
1700 assert!(rkit.db_engine_initialized());
1701
1702 rkit.select_db_engine(demo_engine("/tmp/rkit-b", 768))
1703 .unwrap();
1704
1705 assert!(!rkit.db_engine_initialized());
1706 }
1707
1708 #[tokio::test]
1709 async fn ingest_document_requires_initialization() {
1710 let temp_dir = tempdir().unwrap();
1711 let mut rkit = RKit::new(
1712 demo_engine(temp_dir.path().to_str().unwrap(), 3),
1713 EmbeddingsProviderKind::Ort(EmbeddingsConfig::default()),
1714 )
1715 .unwrap();
1716
1717 let error = rkit
1718 .ingest_document("A short document".to_string())
1719 .await
1720 .unwrap_err();
1721
1722 assert!(matches!(error, IngestDocumentError::NotInitialized));
1723 }
1724
1725 #[tokio::test]
1726 async fn ingest_document_rejects_empty_content() {
1727 let temp_dir = tempdir().unwrap();
1728 let mut rkit = RKit::new(
1729 demo_engine(temp_dir.path().to_str().unwrap(), 3),
1730 EmbeddingsProviderKind::Ort(EmbeddingsConfig::default()),
1731 )
1732 .unwrap();
1733 rkit.db_engine = Some(
1734 super::init_db_engine(demo_engine(temp_dir.path().to_str().unwrap(), 3))
1735 .await
1736 .unwrap(),
1737 );
1738 super::ensure_db_engine_tables(rkit.db_engine.as_ref().unwrap())
1739 .await
1740 .unwrap();
1741 rkit.embeddings_provider = Some(InitializedEmbeddingsProvider::Mock(MockEmbedder::new(3)));
1742
1743 let error = rkit.ingest_document(" ".to_string()).await.unwrap_err();
1744
1745 assert!(matches!(error, IngestDocumentError::EmptyContent));
1746 }
1747
1748 #[tokio::test]
1749 async fn ingest_document_file_requires_initialization_after_reading() {
1750 let temp_dir = tempdir().unwrap();
1751 let mut rkit = RKit::new(
1752 demo_engine(temp_dir.path().to_str().unwrap(), 3),
1753 EmbeddingsProviderKind::Ort(EmbeddingsConfig::default()),
1754 )
1755 .unwrap();
1756 let file_path = temp_dir.path().join("document.txt");
1757 fs::write(&file_path, "A short document").unwrap();
1758
1759 let error = rkit.ingest_document_file(&file_path).await.unwrap_err();
1760
1761 assert!(matches!(error, IngestDocumentError::NotInitialized));
1762 }
1763
1764 #[tokio::test]
1765 async fn ingest_document_file_reads_and_inserts_text_file() {
1766 let temp_dir = tempdir().unwrap();
1767 let mut rkit = RKit::new(
1768 demo_engine(temp_dir.path().to_str().unwrap(), 3),
1769 EmbeddingsProviderKind::Ort(EmbeddingsConfig::default()),
1770 )
1771 .unwrap();
1772 rkit.set_chunking_config(ChunkingConfig {
1773 chunk_size: 24,
1774 overlap_size: 4,
1775 })
1776 .unwrap();
1777 rkit.db_engine = Some(
1778 super::init_db_engine(demo_engine(temp_dir.path().to_str().unwrap(), 3))
1779 .await
1780 .unwrap(),
1781 );
1782 super::ensure_db_engine_tables(rkit.db_engine.as_ref().unwrap())
1783 .await
1784 .unwrap();
1785 rkit.embeddings_provider = Some(InitializedEmbeddingsProvider::Mock(MockEmbedder::new(3)));
1786
1787 let content = "First sentence has enough text to force chunking. Second sentence adds more words for another chunk.";
1788 let file_path = temp_dir.path().join("document.txt");
1789 fs::write(&file_path, content).unwrap();
1790 let expected_chunk_count = super::chunk_text(content, rkit.chunking_config())
1791 .unwrap()
1792 .len();
1793
1794 let result = rkit.ingest_document_file(&file_path).await.unwrap();
1795
1796 assert_eq!(result.chunk_count, expected_chunk_count);
1797 assert!(!result.document_id.is_empty());
1798 assert_eq!(
1799 rkit.get_document(result.document_id.clone()).await.unwrap(),
1800 Some(Document {
1801 document_id: result.document_id,
1802 content: content.to_string(),
1803 })
1804 );
1805
1806 let backend = rkit.lancedb_backend().unwrap();
1807 let table = backend
1808 .connection()
1809 .open_table("chunks")
1810 .execute()
1811 .await
1812 .unwrap();
1813 assert_eq!(
1814 table.count_rows(None).await.unwrap(),
1815 expected_chunk_count as usize
1816 );
1817 }
1818
1819 #[tokio::test]
1820 async fn ingest_document_file_returns_file_read_error_for_missing_path() {
1821 let temp_dir = tempdir().unwrap();
1822 let mut rkit = RKit::new(
1823 demo_engine(temp_dir.path().to_str().unwrap(), 3),
1824 EmbeddingsProviderKind::Ort(EmbeddingsConfig::default()),
1825 )
1826 .unwrap();
1827 let file_path = temp_dir.path().join("missing.txt");
1828
1829 let error = rkit.ingest_document_file(&file_path).await.unwrap_err();
1830
1831 assert!(matches!(
1832 error,
1833 IngestDocumentError::FileRead { path, .. } if path == file_path
1834 ));
1835 }
1836
1837 #[tokio::test]
1838 async fn ingest_document_file_returns_file_read_error_for_invalid_utf8() {
1839 let temp_dir = tempdir().unwrap();
1840 let mut rkit = RKit::new(
1841 demo_engine(temp_dir.path().to_str().unwrap(), 3),
1842 EmbeddingsProviderKind::Ort(EmbeddingsConfig::default()),
1843 )
1844 .unwrap();
1845 let file_path = temp_dir.path().join("invalid.txt");
1846 fs::write(&file_path, [0xff, 0xfe, 0xfd]).unwrap();
1847
1848 let error = rkit.ingest_document_file(&file_path).await.unwrap_err();
1849
1850 assert!(matches!(
1851 error,
1852 IngestDocumentError::FileRead { path, .. } if path == file_path
1853 ));
1854 }
1855
1856 #[tokio::test]
1857 async fn ingest_document_chunks_embeds_and_inserts_rows() {
1858 let temp_dir = tempdir().unwrap();
1859 let mut rkit = RKit::new(
1860 demo_engine(temp_dir.path().to_str().unwrap(), 3),
1861 EmbeddingsProviderKind::Ort(EmbeddingsConfig::default()),
1862 )
1863 .unwrap();
1864 rkit.set_chunking_config(ChunkingConfig {
1865 chunk_size: 24,
1866 overlap_size: 4,
1867 })
1868 .unwrap();
1869 rkit.db_engine = Some(
1870 super::init_db_engine(demo_engine(temp_dir.path().to_str().unwrap(), 3))
1871 .await
1872 .unwrap(),
1873 );
1874 super::ensure_db_engine_tables(rkit.db_engine.as_ref().unwrap())
1875 .await
1876 .unwrap();
1877 rkit.embeddings_provider = Some(InitializedEmbeddingsProvider::Mock(MockEmbedder::new(3)));
1878
1879 let content = "First sentence has enough text to force chunking. Second sentence adds more words for another chunk.".to_string();
1880 let expected_chunk_count = super::chunk_text(&content, rkit.chunking_config())
1881 .unwrap()
1882 .len();
1883
1884 let result = rkit.ingest_document(content.clone()).await.unwrap();
1885
1886 assert_eq!(result.chunk_count, expected_chunk_count);
1887 assert!(!result.document_id.is_empty());
1888 assert_eq!(
1889 rkit.get_document(result.document_id.clone()).await.unwrap(),
1890 Some(Document {
1891 document_id: result.document_id.clone(),
1892 content,
1893 })
1894 );
1895
1896 let backend = rkit.lancedb_backend().unwrap();
1897 let table = backend
1898 .connection()
1899 .open_table("chunks")
1900 .execute()
1901 .await
1902 .unwrap();
1903 assert_eq!(
1904 table.count_rows(None).await.unwrap(),
1905 expected_chunk_count as usize
1906 );
1907
1908 let rows = table.query().execute().await.unwrap();
1909 let batches = rows.try_collect::<Vec<_>>().await.unwrap();
1910 let document_ids = batches
1911 .iter()
1912 .flat_map(|batch| {
1913 batch
1914 .column_by_name("document_id")
1915 .unwrap()
1916 .as_any()
1917 .downcast_ref::<StringArray>()
1918 .unwrap()
1919 .iter()
1920 .flatten()
1921 .map(str::to_owned)
1922 .collect::<Vec<_>>()
1923 })
1924 .collect::<Vec<_>>();
1925
1926 assert_eq!(document_ids.len(), expected_chunk_count);
1927 assert!(document_ids.iter().all(|id| id == &result.document_id));
1928 }
1929
1930 #[tokio::test]
1931 async fn list_documents_requires_initialization() {
1932 let temp_dir = tempdir().unwrap();
1933 let rkit = RKit::new(
1934 demo_engine(temp_dir.path().to_str().unwrap(), 3),
1935 EmbeddingsProviderKind::Ort(EmbeddingsConfig::default()),
1936 )
1937 .unwrap();
1938
1939 let error = rkit.list_documents().await.unwrap_err();
1940
1941 assert!(matches!(error, DocumentError::NotInitialized));
1942 }
1943
1944 #[tokio::test]
1945 async fn get_document_requires_initialization() {
1946 let temp_dir = tempdir().unwrap();
1947 let rkit = RKit::new(
1948 demo_engine(temp_dir.path().to_str().unwrap(), 3),
1949 EmbeddingsProviderKind::Ort(EmbeddingsConfig::default()),
1950 )
1951 .unwrap();
1952
1953 let error = rkit
1954 .get_document("missing-doc".to_string())
1955 .await
1956 .unwrap_err();
1957
1958 assert!(matches!(error, DocumentError::NotInitialized));
1959 }
1960
1961 #[tokio::test]
1962 async fn delete_document_requires_initialization() {
1963 let temp_dir = tempdir().unwrap();
1964 let rkit = RKit::new(
1965 demo_engine(temp_dir.path().to_str().unwrap(), 3),
1966 EmbeddingsProviderKind::Ort(EmbeddingsConfig::default()),
1967 )
1968 .unwrap();
1969
1970 let error = rkit
1971 .delete_document("missing-doc".to_string())
1972 .await
1973 .unwrap_err();
1974
1975 assert!(matches!(error, DocumentError::NotInitialized));
1976 }
1977
1978 #[tokio::test]
1979 async fn vector_search_requires_initialization() {
1980 let temp_dir = tempdir().unwrap();
1981 let rkit = RKit::new(
1982 demo_engine(temp_dir.path().to_str().unwrap(), 3),
1983 EmbeddingsProviderKind::Ort(EmbeddingsConfig::default()),
1984 )
1985 .unwrap();
1986
1987 let error = rkit
1988 .vector_search("meaningful query".to_string(), 3)
1989 .await
1990 .unwrap_err();
1991
1992 assert!(matches!(error, VectorSearchError::NotInitialized));
1993 }
1994
1995 #[tokio::test]
1996 async fn vector_search_rejects_blank_query() {
1997 let temp_dir = tempdir().unwrap();
1998 let mut rkit = RKit::new(
1999 demo_engine(temp_dir.path().to_str().unwrap(), 3),
2000 EmbeddingsProviderKind::Ort(EmbeddingsConfig::default()),
2001 )
2002 .unwrap();
2003 rkit.db_engine = Some(
2004 super::init_db_engine(demo_engine(temp_dir.path().to_str().unwrap(), 3))
2005 .await
2006 .unwrap(),
2007 );
2008 super::ensure_db_engine_tables(rkit.db_engine.as_ref().unwrap())
2009 .await
2010 .unwrap();
2011 rkit.embeddings_provider = Some(InitializedEmbeddingsProvider::Mock(MockEmbedder::new(3)));
2012
2013 let error = rkit.vector_search(" ".to_string(), 3).await.unwrap_err();
2014
2015 assert!(matches!(error, VectorSearchError::EmptyQuery));
2016 }
2017
2018 #[tokio::test]
2019 async fn vector_search_zero_limit_returns_empty_without_embedding() {
2020 let temp_dir = tempdir().unwrap();
2021 let mut rkit = RKit::new(
2022 demo_engine(temp_dir.path().to_str().unwrap(), 3),
2023 EmbeddingsProviderKind::Ort(EmbeddingsConfig::default()),
2024 )
2025 .unwrap();
2026 rkit.db_engine = Some(
2027 super::init_db_engine(demo_engine(temp_dir.path().to_str().unwrap(), 3))
2028 .await
2029 .unwrap(),
2030 );
2031 super::ensure_db_engine_tables(rkit.db_engine.as_ref().unwrap())
2032 .await
2033 .unwrap();
2034 rkit.embeddings_provider = Some(InitializedEmbeddingsProvider::Mock(MockEmbedder::new(3)));
2035
2036 let results = rkit
2037 .vector_search("meaningful query".to_string(), 0)
2038 .await
2039 .unwrap();
2040
2041 assert!(results.is_empty());
2042 match rkit.embeddings_provider.as_ref().unwrap() {
2043 InitializedEmbeddingsProvider::Mock(embedder) => assert_eq!(embedder.call_count(), 0),
2044 InitializedEmbeddingsProvider::Ort(_) => unreachable!(),
2045 }
2046 }
2047
2048 #[tokio::test]
2049 async fn vector_search_embeds_query_and_returns_limited_ranked_chunks() {
2050 let temp_dir = tempdir().unwrap();
2051 let mut rkit = RKit::new(
2052 demo_engine(temp_dir.path().to_str().unwrap(), 3),
2053 EmbeddingsProviderKind::Ort(EmbeddingsConfig::default()),
2054 )
2055 .unwrap();
2056 rkit.db_engine = Some(
2057 super::init_db_engine(demo_engine(temp_dir.path().to_str().unwrap(), 3))
2058 .await
2059 .unwrap(),
2060 );
2061 super::ensure_db_engine_tables(rkit.db_engine.as_ref().unwrap())
2062 .await
2063 .unwrap();
2064 rkit.embeddings_provider = Some(InitializedEmbeddingsProvider::Mock(MockEmbedder::new(3)));
2065 let first = rkit
2066 .upsert_document(
2067 "first-doc".to_string(),
2068 "First document content.".to_string(),
2069 )
2070 .await
2071 .unwrap();
2072 let second = rkit
2073 .upsert_document(
2074 "second-doc".to_string(),
2075 "Second document content.".to_string(),
2076 )
2077 .await
2078 .unwrap();
2079
2080 let results = rkit
2081 .vector_search("content like the second document".to_string(), 1)
2082 .await
2083 .unwrap();
2084
2085 assert_eq!(
2086 results,
2087 vec![VectorSearchResult {
2088 document_id: second.document_id,
2089 text: "Second document content.".to_string(),
2090 distance: 3.0,
2091 }]
2092 );
2093 assert_eq!(first.chunk_count, 1);
2094 match rkit.embeddings_provider.as_ref().unwrap() {
2095 InitializedEmbeddingsProvider::Mock(embedder) => assert_eq!(embedder.call_count(), 3),
2096 InitializedEmbeddingsProvider::Ort(_) => unreachable!(),
2097 }
2098 }
2099
2100 #[tokio::test]
2101 async fn keyword_search_requires_initialization() {
2102 let temp_dir = tempdir().unwrap();
2103 let rkit = RKit::new(
2104 demo_engine(temp_dir.path().to_str().unwrap(), 3),
2105 EmbeddingsProviderKind::Ort(EmbeddingsConfig::default()),
2106 )
2107 .unwrap();
2108
2109 let error = rkit
2110 .keyword_search("meaningful query".to_string(), 3)
2111 .await
2112 .unwrap_err();
2113
2114 assert!(matches!(error, KeywordSearchError::NotInitialized));
2115 }
2116
2117 #[tokio::test]
2118 async fn keyword_search_rejects_blank_query() {
2119 let temp_dir = tempdir().unwrap();
2120 let mut rkit = RKit::new(
2121 demo_engine(temp_dir.path().to_str().unwrap(), 3),
2122 EmbeddingsProviderKind::Ort(EmbeddingsConfig::default()),
2123 )
2124 .unwrap();
2125 rkit.db_engine = Some(
2126 super::init_db_engine(demo_engine(temp_dir.path().to_str().unwrap(), 3))
2127 .await
2128 .unwrap(),
2129 );
2130 super::ensure_db_engine_tables(rkit.db_engine.as_ref().unwrap())
2131 .await
2132 .unwrap();
2133 rkit.embeddings_provider = Some(InitializedEmbeddingsProvider::Mock(MockEmbedder::new(3)));
2134
2135 let error = rkit.keyword_search(" ".to_string(), 3).await.unwrap_err();
2136
2137 assert!(matches!(error, KeywordSearchError::EmptyQuery));
2138 }
2139
2140 #[tokio::test]
2141 async fn keyword_search_zero_limit_returns_empty_without_embedding() {
2142 let temp_dir = tempdir().unwrap();
2143 let mut rkit = RKit::new(
2144 demo_engine(temp_dir.path().to_str().unwrap(), 3),
2145 EmbeddingsProviderKind::Ort(EmbeddingsConfig::default()),
2146 )
2147 .unwrap();
2148 rkit.db_engine = Some(
2149 super::init_db_engine(demo_engine(temp_dir.path().to_str().unwrap(), 3))
2150 .await
2151 .unwrap(),
2152 );
2153 super::ensure_db_engine_tables(rkit.db_engine.as_ref().unwrap())
2154 .await
2155 .unwrap();
2156 rkit.embeddings_provider = Some(InitializedEmbeddingsProvider::Mock(MockEmbedder::new(3)));
2157
2158 let results = rkit
2159 .keyword_search("meaningful query".to_string(), 0)
2160 .await
2161 .unwrap();
2162
2163 assert!(results.is_empty());
2164 match rkit.embeddings_provider.as_ref().unwrap() {
2165 InitializedEmbeddingsProvider::Mock(embedder) => assert_eq!(embedder.call_count(), 0),
2166 InitializedEmbeddingsProvider::Ort(_) => unreachable!(),
2167 }
2168 }
2169
2170 #[tokio::test]
2171 async fn keyword_search_returns_limited_ranked_chunks() {
2172 let temp_dir = tempdir().unwrap();
2173 let mut rkit = RKit::new(
2174 demo_engine(temp_dir.path().to_str().unwrap(), 3),
2175 EmbeddingsProviderKind::Ort(EmbeddingsConfig::default()),
2176 )
2177 .unwrap();
2178 rkit.db_engine = Some(
2179 super::init_db_engine(demo_engine(temp_dir.path().to_str().unwrap(), 3))
2180 .await
2181 .unwrap(),
2182 );
2183 super::ensure_db_engine_tables(rkit.db_engine.as_ref().unwrap())
2184 .await
2185 .unwrap();
2186 rkit.embeddings_provider = Some(InitializedEmbeddingsProvider::Mock(MockEmbedder::new(3)));
2187 let result = rkit
2188 .upsert_document(
2189 "keyword-doc".to_string(),
2190 "Rust search database. Plain ranger path.".to_string(),
2191 )
2192 .await
2193 .unwrap();
2194
2195 let results = rkit.keyword_search("rust".to_string(), 1).await.unwrap();
2196
2197 assert_eq!(results.len(), 1);
2198 assert_eq!(results[0].document_id, result.document_id);
2199 assert_eq!(results[0].text, "Rust search database. Plain ranger path.");
2200 assert!(results[0].score > 0.0);
2201 match rkit.embeddings_provider.as_ref().unwrap() {
2202 InitializedEmbeddingsProvider::Mock(embedder) => assert_eq!(embedder.call_count(), 1),
2203 InitializedEmbeddingsProvider::Ort(_) => unreachable!(),
2204 }
2205 }
2206
2207 #[tokio::test]
2208 async fn ingest_document_propagates_embedding_dimension_mismatch() {
2209 let temp_dir = tempdir().unwrap();
2210 let mut rkit = RKit::new(
2211 demo_engine(temp_dir.path().to_str().unwrap(), 5),
2212 EmbeddingsProviderKind::Ort(EmbeddingsConfig::default()),
2213 )
2214 .unwrap();
2215 rkit.db_engine = Some(
2216 super::init_db_engine(demo_engine(temp_dir.path().to_str().unwrap(), 5))
2217 .await
2218 .unwrap(),
2219 );
2220 super::ensure_db_engine_tables(rkit.db_engine.as_ref().unwrap())
2221 .await
2222 .unwrap();
2223 rkit.embeddings_provider = Some(InitializedEmbeddingsProvider::Mock(MockEmbedder::new(3)));
2224
2225 let error = rkit
2226 .ingest_document("A document with one chunk.".to_string())
2227 .await
2228 .unwrap_err();
2229
2230 assert!(matches!(error, IngestDocumentError::DbEngine(_)));
2231
2232 let backend = rkit.lancedb_backend().unwrap();
2233 let table = backend
2234 .connection()
2235 .open_table("chunks")
2236 .execute()
2237 .await
2238 .unwrap();
2239 assert_eq!(table.count_rows(None).await.unwrap(), 0);
2240 assert!(table_stored_documents(&rkit).await.is_empty());
2241 }
2242
2243 #[tokio::test]
2244 async fn repeated_ingest_document_calls_append_distinct_document_ids() {
2245 let temp_dir = tempdir().unwrap();
2246 let mut rkit = RKit::new(
2247 demo_engine(temp_dir.path().to_str().unwrap(), 3),
2248 EmbeddingsProviderKind::Ort(EmbeddingsConfig::default()),
2249 )
2250 .unwrap();
2251 rkit.db_engine = Some(
2252 super::init_db_engine(demo_engine(temp_dir.path().to_str().unwrap(), 3))
2253 .await
2254 .unwrap(),
2255 );
2256 super::ensure_db_engine_tables(rkit.db_engine.as_ref().unwrap())
2257 .await
2258 .unwrap();
2259 rkit.embeddings_provider = Some(InitializedEmbeddingsProvider::Mock(MockEmbedder::new(3)));
2260
2261 let first = rkit
2262 .ingest_document("First document content.".to_string())
2263 .await
2264 .unwrap();
2265 let second = rkit
2266 .ingest_document("Second document content.".to_string())
2267 .await
2268 .unwrap();
2269
2270 assert_ne!(first.document_id, second.document_id);
2271
2272 let backend = rkit.lancedb_backend().unwrap();
2273 let table = backend
2274 .connection()
2275 .open_table("chunks")
2276 .execute()
2277 .await
2278 .unwrap();
2279 assert_eq!(
2280 table.count_rows(None).await.unwrap(),
2281 first.chunk_count + second.chunk_count
2282 );
2283 }
2284
2285 #[tokio::test]
2286 async fn list_and_get_documents_return_stored_full_content() {
2287 let temp_dir = tempdir().unwrap();
2288 let mut rkit = RKit::new(
2289 demo_engine(temp_dir.path().to_str().unwrap(), 3),
2290 EmbeddingsProviderKind::Ort(EmbeddingsConfig::default()),
2291 )
2292 .unwrap();
2293 rkit.db_engine = Some(
2294 super::init_db_engine(demo_engine(temp_dir.path().to_str().unwrap(), 3))
2295 .await
2296 .unwrap(),
2297 );
2298 super::ensure_db_engine_tables(rkit.db_engine.as_ref().unwrap())
2299 .await
2300 .unwrap();
2301 rkit.embeddings_provider = Some(InitializedEmbeddingsProvider::Mock(MockEmbedder::new(3)));
2302
2303 let first = rkit
2304 .upsert_document(
2305 "b-doc".to_string(),
2306 "First full document content.".to_string(),
2307 )
2308 .await
2309 .unwrap();
2310 let second_content = "Second full document content.".to_string();
2311 let second = rkit
2312 .upsert_document("a-doc".to_string(), second_content.clone())
2313 .await
2314 .unwrap();
2315
2316 assert_eq!(
2317 rkit.list_documents().await.unwrap(),
2318 vec![
2319 DocumentSummary {
2320 document_id: second.document_id.clone(),
2321 },
2322 DocumentSummary {
2323 document_id: first.document_id,
2324 },
2325 ]
2326 );
2327 assert_eq!(
2328 rkit.get_document(second.document_id.clone()).await.unwrap(),
2329 Some(Document {
2330 document_id: second.document_id,
2331 content: second_content,
2332 })
2333 );
2334 assert_eq!(
2335 rkit.get_document("missing-doc".to_string()).await.unwrap(),
2336 None
2337 );
2338 }
2339
2340 #[tokio::test]
2341 async fn delete_document_removes_stored_content_and_chunks() {
2342 let temp_dir = tempdir().unwrap();
2343 let mut rkit = RKit::new(
2344 demo_engine(temp_dir.path().to_str().unwrap(), 3),
2345 EmbeddingsProviderKind::Ort(EmbeddingsConfig::default()),
2346 )
2347 .unwrap();
2348 rkit.db_engine = Some(
2349 super::init_db_engine(demo_engine(temp_dir.path().to_str().unwrap(), 3))
2350 .await
2351 .unwrap(),
2352 );
2353 super::ensure_db_engine_tables(rkit.db_engine.as_ref().unwrap())
2354 .await
2355 .unwrap();
2356 rkit.embeddings_provider = Some(InitializedEmbeddingsProvider::Mock(MockEmbedder::new(3)));
2357 rkit.upsert_document("manual-doc".to_string(), "Document content.".to_string())
2358 .await
2359 .unwrap();
2360
2361 rkit.delete_document("manual-doc".to_string())
2362 .await
2363 .unwrap();
2364 rkit.delete_document("missing-doc".to_string())
2365 .await
2366 .unwrap();
2367
2368 assert_eq!(rkit.list_documents().await.unwrap(), Vec::new());
2369 assert_eq!(table_document_ids(&rkit).await, Vec::<String>::new());
2370 }
2371
2372 #[tokio::test]
2373 async fn upsert_document_inserts_when_document_id_is_new() {
2374 let temp_dir = tempdir().unwrap();
2375 let mut rkit = RKit::new(
2376 demo_engine(temp_dir.path().to_str().unwrap(), 3),
2377 EmbeddingsProviderKind::Ort(EmbeddingsConfig::default()),
2378 )
2379 .unwrap();
2380 rkit.db_engine = Some(
2381 super::init_db_engine(demo_engine(temp_dir.path().to_str().unwrap(), 3))
2382 .await
2383 .unwrap(),
2384 );
2385 super::ensure_db_engine_tables(rkit.db_engine.as_ref().unwrap())
2386 .await
2387 .unwrap();
2388 rkit.embeddings_provider = Some(InitializedEmbeddingsProvider::Mock(MockEmbedder::new(3)));
2389
2390 let result = rkit
2391 .upsert_document(
2392 "manual-doc".to_string(),
2393 "New document content.".to_string(),
2394 )
2395 .await
2396 .unwrap();
2397
2398 assert_eq!(result.document_id, "manual-doc");
2399 assert_eq!(result.chunk_count, 1);
2400 assert_eq!(
2401 rkit.get_document("manual-doc".to_string()).await.unwrap(),
2402 Some(Document {
2403 document_id: "manual-doc".to_string(),
2404 content: "New document content.".to_string(),
2405 })
2406 );
2407 assert_eq!(
2408 table_document_ids(&rkit).await,
2409 vec!["manual-doc".to_string()]
2410 );
2411 }
2412
2413 #[tokio::test]
2414 async fn upsert_document_replaces_existing_chunks_for_document_id() {
2415 let temp_dir = tempdir().unwrap();
2416 let mut rkit = RKit::new(
2417 demo_engine(temp_dir.path().to_str().unwrap(), 3),
2418 EmbeddingsProviderKind::Ort(EmbeddingsConfig::default()),
2419 )
2420 .unwrap();
2421 rkit.set_chunking_config(ChunkingConfig {
2422 chunk_size: 24,
2423 overlap_size: 4,
2424 })
2425 .unwrap();
2426 rkit.db_engine = Some(
2427 super::init_db_engine(demo_engine(temp_dir.path().to_str().unwrap(), 3))
2428 .await
2429 .unwrap(),
2430 );
2431 super::ensure_db_engine_tables(rkit.db_engine.as_ref().unwrap())
2432 .await
2433 .unwrap();
2434 rkit.embeddings_provider = Some(InitializedEmbeddingsProvider::Mock(MockEmbedder::new(3)));
2435
2436 rkit.upsert_document(
2437 "manual-doc".to_string(),
2438 "First sentence has enough text to force chunking. Second sentence adds more words."
2439 .to_string(),
2440 )
2441 .await
2442 .unwrap();
2443 let replacement = "Short replacement.".to_string();
2444 let result = rkit
2445 .upsert_document("manual-doc".to_string(), replacement.clone())
2446 .await
2447 .unwrap();
2448
2449 assert_eq!(result.document_id, "manual-doc");
2450 assert_eq!(result.chunk_count, 1);
2451 assert_eq!(
2452 table_document_ids(&rkit).await,
2453 vec!["manual-doc".to_string()]
2454 );
2455 assert_eq!(table_texts(&rkit).await, vec![replacement.clone()]);
2456 assert_eq!(
2457 rkit.get_document("manual-doc".to_string()).await.unwrap(),
2458 Some(Document {
2459 document_id: "manual-doc".to_string(),
2460 content: replacement,
2461 })
2462 );
2463 }
2464
2465 #[tokio::test]
2466 async fn upsert_document_regenerates_embeddings_for_replacement_content() {
2467 let temp_dir = tempdir().unwrap();
2468 let mut rkit = RKit::new(
2469 demo_engine(temp_dir.path().to_str().unwrap(), 3),
2470 EmbeddingsProviderKind::Ort(EmbeddingsConfig::default()),
2471 )
2472 .unwrap();
2473 rkit.db_engine = Some(
2474 super::init_db_engine(demo_engine(temp_dir.path().to_str().unwrap(), 3))
2475 .await
2476 .unwrap(),
2477 );
2478 super::ensure_db_engine_tables(rkit.db_engine.as_ref().unwrap())
2479 .await
2480 .unwrap();
2481 rkit.embeddings_provider = Some(InitializedEmbeddingsProvider::Mock(MockEmbedder::new(3)));
2482
2483 rkit.upsert_document("manual-doc".to_string(), "Original content.".to_string())
2484 .await
2485 .unwrap();
2486 assert_eq!(table_vectors(&rkit).await, vec![vec![1.0, 1.0, 1.0]]);
2487
2488 rkit.upsert_document("manual-doc".to_string(), "Replacement content.".to_string())
2489 .await
2490 .unwrap();
2491
2492 assert_eq!(table_vectors(&rkit).await, vec![vec![2.0, 2.0, 2.0]]);
2493 match rkit.embeddings_provider.as_ref().unwrap() {
2494 InitializedEmbeddingsProvider::Mock(embedder) => assert_eq!(embedder.call_count(), 2),
2495 InitializedEmbeddingsProvider::Ort(_) => unreachable!(),
2496 }
2497 }
2498
2499 #[tokio::test]
2500 async fn upsert_document_rejects_empty_content_without_mutating_rows() {
2501 let temp_dir = tempdir().unwrap();
2502 let mut rkit = RKit::new(
2503 demo_engine(temp_dir.path().to_str().unwrap(), 3),
2504 EmbeddingsProviderKind::Ort(EmbeddingsConfig::default()),
2505 )
2506 .unwrap();
2507 rkit.db_engine = Some(
2508 super::init_db_engine(demo_engine(temp_dir.path().to_str().unwrap(), 3))
2509 .await
2510 .unwrap(),
2511 );
2512 super::ensure_db_engine_tables(rkit.db_engine.as_ref().unwrap())
2513 .await
2514 .unwrap();
2515 rkit.embeddings_provider = Some(InitializedEmbeddingsProvider::Mock(MockEmbedder::new(3)));
2516 rkit.upsert_document("manual-doc".to_string(), "Original content.".to_string())
2517 .await
2518 .unwrap();
2519
2520 let error = rkit
2521 .upsert_document("manual-doc".to_string(), " ".to_string())
2522 .await
2523 .unwrap_err();
2524
2525 assert!(matches!(error, IngestDocumentError::EmptyContent));
2526 assert_eq!(
2527 table_texts(&rkit).await,
2528 vec!["Original content.".to_string()]
2529 );
2530 assert_eq!(
2531 rkit.get_document("manual-doc".to_string()).await.unwrap(),
2532 Some(Document {
2533 document_id: "manual-doc".to_string(),
2534 content: "Original content.".to_string(),
2535 })
2536 );
2537 }
2538
2539 #[tokio::test]
2540 async fn upsert_document_preserves_existing_rows_when_vectors_are_invalid() {
2541 let temp_dir = tempdir().unwrap();
2542 let mut rkit = RKit::new(
2543 demo_engine(temp_dir.path().to_str().unwrap(), 5),
2544 EmbeddingsProviderKind::Ort(EmbeddingsConfig::default()),
2545 )
2546 .unwrap();
2547 rkit.db_engine = Some(
2548 super::init_db_engine(demo_engine(temp_dir.path().to_str().unwrap(), 5))
2549 .await
2550 .unwrap(),
2551 );
2552 super::ensure_db_engine_tables(rkit.db_engine.as_ref().unwrap())
2553 .await
2554 .unwrap();
2555 rkit.embeddings_provider = Some(InitializedEmbeddingsProvider::Mock(MockEmbedder::new(5)));
2556 rkit.upsert_document("manual-doc".to_string(), "Original content.".to_string())
2557 .await
2558 .unwrap();
2559 rkit.embeddings_provider = Some(InitializedEmbeddingsProvider::Mock(MockEmbedder::new(3)));
2560
2561 let error = rkit
2562 .upsert_document("manual-doc".to_string(), "Replacement content.".to_string())
2563 .await
2564 .unwrap_err();
2565
2566 assert!(matches!(error, IngestDocumentError::DbEngine(_)));
2567 assert_eq!(
2568 table_texts(&rkit).await,
2569 vec!["Original content.".to_string()]
2570 );
2571 assert_eq!(
2572 rkit.get_document("manual-doc".to_string()).await.unwrap(),
2573 Some(Document {
2574 document_id: "manual-doc".to_string(),
2575 content: "Original content.".to_string(),
2576 })
2577 );
2578 }
2579
2580 #[tokio::test]
2581 async fn ingest_documents_requires_initialization() {
2582 let temp_dir = tempdir().unwrap();
2583 let mut rkit = RKit::new(
2584 demo_engine(temp_dir.path().to_str().unwrap(), 3),
2585 EmbeddingsProviderKind::Ort(EmbeddingsConfig::default()),
2586 )
2587 .unwrap();
2588
2589 let error = rkit
2590 .ingest_documents(vec!["A short document".to_string()])
2591 .await
2592 .unwrap_err();
2593
2594 assert!(matches!(error, IngestDocumentError::NotInitialized));
2595 }
2596
2597 #[tokio::test]
2598 async fn ingest_documents_rejects_empty_batch() {
2599 let temp_dir = tempdir().unwrap();
2600 let mut rkit = RKit::new(
2601 demo_engine(temp_dir.path().to_str().unwrap(), 3),
2602 EmbeddingsProviderKind::Ort(EmbeddingsConfig::default()),
2603 )
2604 .unwrap();
2605 rkit.db_engine = Some(
2606 super::init_db_engine(demo_engine(temp_dir.path().to_str().unwrap(), 3))
2607 .await
2608 .unwrap(),
2609 );
2610 super::ensure_db_engine_tables(rkit.db_engine.as_ref().unwrap())
2611 .await
2612 .unwrap();
2613 rkit.embeddings_provider = Some(InitializedEmbeddingsProvider::Mock(MockEmbedder::new(3)));
2614
2615 let error = rkit.ingest_documents(Vec::new()).await.unwrap_err();
2616
2617 assert!(matches!(error, IngestDocumentError::EmptyContent));
2618 }
2619
2620 #[tokio::test]
2621 async fn ingest_document_files_reads_multiple_files_from_glob() {
2622 let temp_dir = tempdir().unwrap();
2623 let mut rkit = RKit::new(
2624 demo_engine(temp_dir.path().to_str().unwrap(), 3),
2625 EmbeddingsProviderKind::Ort(EmbeddingsConfig::default()),
2626 )
2627 .unwrap();
2628 rkit.set_chunking_config(ChunkingConfig {
2629 chunk_size: 24,
2630 overlap_size: 4,
2631 })
2632 .unwrap();
2633 rkit.db_engine = Some(
2634 super::init_db_engine(demo_engine(temp_dir.path().to_str().unwrap(), 3))
2635 .await
2636 .unwrap(),
2637 );
2638 super::ensure_db_engine_tables(rkit.db_engine.as_ref().unwrap())
2639 .await
2640 .unwrap();
2641 rkit.embeddings_provider = Some(InitializedEmbeddingsProvider::Mock(MockEmbedder::new(3)));
2642
2643 let first_content =
2644 "First sentence has enough text to force chunking. Second sentence adds more words.";
2645 let second_content = "Another document with enough content to create at least one chunk.";
2646 fs::write(temp_dir.path().join("a.txt"), first_content).unwrap();
2647 fs::write(temp_dir.path().join("b.txt"), second_content).unwrap();
2648 let expected_first = super::chunk_text(first_content, rkit.chunking_config())
2649 .unwrap()
2650 .len();
2651 let expected_second = super::chunk_text(second_content, rkit.chunking_config())
2652 .unwrap()
2653 .len();
2654 let pattern = format!("{}/*.txt", temp_dir.path().display());
2655
2656 let results = rkit.ingest_document_files(&pattern).await.unwrap();
2657
2658 assert_eq!(results.len(), 2);
2659 assert_eq!(results[0].chunk_count, expected_first);
2660 assert_eq!(results[1].chunk_count, expected_second);
2661 assert_ne!(results[0].document_id, results[1].document_id);
2662
2663 let backend = rkit.lancedb_backend().unwrap();
2664 let table = backend
2665 .connection()
2666 .open_table("chunks")
2667 .execute()
2668 .await
2669 .unwrap();
2670 assert_eq!(
2671 table.count_rows(None).await.unwrap(),
2672 expected_first + expected_second
2673 );
2674 }
2675
2676 #[tokio::test]
2677 async fn ingest_document_files_rejects_invalid_glob_pattern() {
2678 let temp_dir = tempdir().unwrap();
2679 let mut rkit = RKit::new(
2680 demo_engine(temp_dir.path().to_str().unwrap(), 3),
2681 EmbeddingsProviderKind::Ort(EmbeddingsConfig::default()),
2682 )
2683 .unwrap();
2684
2685 let error = rkit.ingest_document_files("[").await.unwrap_err();
2686
2687 assert!(matches!(error, IngestDocumentError::InvalidGlobPattern(_)));
2688 }
2689
2690 #[tokio::test]
2691 async fn ingest_document_files_rejects_when_glob_matches_nothing() {
2692 let temp_dir = tempdir().unwrap();
2693 let mut rkit = RKit::new(
2694 demo_engine(temp_dir.path().to_str().unwrap(), 3),
2695 EmbeddingsProviderKind::Ort(EmbeddingsConfig::default()),
2696 )
2697 .unwrap();
2698 let pattern = format!("{}/*.txt", temp_dir.path().display());
2699
2700 let error = rkit.ingest_document_files(&pattern).await.unwrap_err();
2701
2702 assert!(matches!(
2703 error,
2704 IngestDocumentError::NoFilesMatched { pattern: actual } if actual == pattern
2705 ));
2706 }
2707
2708 #[tokio::test]
2709 async fn ingest_document_files_prevalidates_all_matched_files_before_inserting() {
2710 let temp_dir = tempdir().unwrap();
2711 let mut rkit = RKit::new(
2712 demo_engine(temp_dir.path().to_str().unwrap(), 3),
2713 EmbeddingsProviderKind::Ort(EmbeddingsConfig::default()),
2714 )
2715 .unwrap();
2716 rkit.db_engine = Some(
2717 super::init_db_engine(demo_engine(temp_dir.path().to_str().unwrap(), 3))
2718 .await
2719 .unwrap(),
2720 );
2721 super::ensure_db_engine_tables(rkit.db_engine.as_ref().unwrap())
2722 .await
2723 .unwrap();
2724 rkit.embeddings_provider = Some(InitializedEmbeddingsProvider::Mock(MockEmbedder::new(3)));
2725
2726 fs::write(temp_dir.path().join("valid.txt"), "valid document").unwrap();
2727 fs::create_dir(temp_dir.path().join("invalid.txt")).unwrap();
2728 let pattern = format!("{}/*.txt", temp_dir.path().display());
2729
2730 let error = rkit.ingest_document_files(&pattern).await.unwrap_err();
2731
2732 assert!(matches!(error, IngestDocumentError::FileRead { .. }));
2733
2734 let backend = rkit.lancedb_backend().unwrap();
2735 let table = backend
2736 .connection()
2737 .open_table("chunks")
2738 .execute()
2739 .await
2740 .unwrap();
2741 assert_eq!(table.count_rows(None).await.unwrap(), 0);
2742 assert!(table_stored_documents(&rkit).await.is_empty());
2743 }
2744
2745 #[tokio::test]
2746 async fn ingest_documents_prevalidates_all_inputs_before_inserting() {
2747 let temp_dir = tempdir().unwrap();
2748 let mut rkit = RKit::new(
2749 demo_engine(temp_dir.path().to_str().unwrap(), 3),
2750 EmbeddingsProviderKind::Ort(EmbeddingsConfig::default()),
2751 )
2752 .unwrap();
2753 rkit.db_engine = Some(
2754 super::init_db_engine(demo_engine(temp_dir.path().to_str().unwrap(), 3))
2755 .await
2756 .unwrap(),
2757 );
2758 super::ensure_db_engine_tables(rkit.db_engine.as_ref().unwrap())
2759 .await
2760 .unwrap();
2761 rkit.embeddings_provider = Some(InitializedEmbeddingsProvider::Mock(MockEmbedder::new(3)));
2762
2763 let error = rkit
2764 .ingest_documents(vec!["valid document".to_string(), " ".to_string()])
2765 .await
2766 .unwrap_err();
2767
2768 assert!(matches!(error, IngestDocumentError::EmptyContent));
2769
2770 let backend = rkit.lancedb_backend().unwrap();
2771 let table = backend
2772 .connection()
2773 .open_table("chunks")
2774 .execute()
2775 .await
2776 .unwrap();
2777 assert_eq!(table.count_rows(None).await.unwrap(), 0);
2778 assert!(table_stored_documents(&rkit).await.is_empty());
2779 }
2780
2781 #[tokio::test]
2782 async fn ingest_documents_returns_one_result_per_input_document() {
2783 let temp_dir = tempdir().unwrap();
2784 let mut rkit = RKit::new(
2785 demo_engine(temp_dir.path().to_str().unwrap(), 3),
2786 EmbeddingsProviderKind::Ort(EmbeddingsConfig::default()),
2787 )
2788 .unwrap();
2789 rkit.set_chunking_config(ChunkingConfig {
2790 chunk_size: 24,
2791 overlap_size: 4,
2792 })
2793 .unwrap();
2794 rkit.db_engine = Some(
2795 super::init_db_engine(demo_engine(temp_dir.path().to_str().unwrap(), 3))
2796 .await
2797 .unwrap(),
2798 );
2799 super::ensure_db_engine_tables(rkit.db_engine.as_ref().unwrap())
2800 .await
2801 .unwrap();
2802 rkit.embeddings_provider = Some(InitializedEmbeddingsProvider::Mock(MockEmbedder::new(3)));
2803
2804 let first_content =
2805 "First sentence has enough text to force chunking. Second sentence adds more words."
2806 .to_string();
2807 let second_content =
2808 "Another document with enough content to create at least one chunk.".to_string();
2809 let expected_first = super::chunk_text(&first_content, rkit.chunking_config())
2810 .unwrap()
2811 .len();
2812 let expected_second = super::chunk_text(&second_content, rkit.chunking_config())
2813 .unwrap()
2814 .len();
2815
2816 let results = rkit
2817 .ingest_documents(vec![first_content.clone(), second_content.clone()])
2818 .await
2819 .unwrap();
2820
2821 assert_eq!(results.len(), 2);
2822 assert_eq!(results[0].chunk_count, expected_first);
2823 assert_eq!(results[1].chunk_count, expected_second);
2824 assert_ne!(results[0].document_id, results[1].document_id);
2825 assert_eq!(
2826 rkit.get_document(results[0].document_id.clone())
2827 .await
2828 .unwrap(),
2829 Some(Document {
2830 document_id: results[0].document_id.clone(),
2831 content: first_content,
2832 })
2833 );
2834 assert_eq!(
2835 rkit.get_document(results[1].document_id.clone())
2836 .await
2837 .unwrap(),
2838 Some(Document {
2839 document_id: results[1].document_id.clone(),
2840 content: second_content,
2841 })
2842 );
2843
2844 let backend = rkit.lancedb_backend().unwrap();
2845 let table = backend
2846 .connection()
2847 .open_table("chunks")
2848 .execute()
2849 .await
2850 .unwrap();
2851 assert_eq!(
2852 table.count_rows(None).await.unwrap(),
2853 expected_first + expected_second
2854 );
2855
2856 let rows = table.query().execute().await.unwrap();
2857 let batches = rows.try_collect::<Vec<_>>().await.unwrap();
2858 let document_ids = batches
2859 .iter()
2860 .flat_map(|batch| {
2861 batch
2862 .column_by_name("document_id")
2863 .unwrap()
2864 .as_any()
2865 .downcast_ref::<StringArray>()
2866 .unwrap()
2867 .iter()
2868 .flatten()
2869 .map(str::to_owned)
2870 .collect::<Vec<_>>()
2871 })
2872 .collect::<Vec<_>>();
2873
2874 let first_count = document_ids
2875 .iter()
2876 .filter(|id| *id == &results[0].document_id)
2877 .count();
2878 let second_count = document_ids
2879 .iter()
2880 .filter(|id| *id == &results[1].document_id)
2881 .count();
2882
2883 assert_eq!(first_count, expected_first);
2884 assert_eq!(second_count, expected_second);
2885 }
2886
2887 #[tokio::test]
2888 async fn ingest_documents_propagates_embedding_dimension_mismatch_without_inserting() {
2889 let temp_dir = tempdir().unwrap();
2890 let mut rkit = RKit::new(
2891 demo_engine(temp_dir.path().to_str().unwrap(), 5),
2892 EmbeddingsProviderKind::Ort(EmbeddingsConfig::default()),
2893 )
2894 .unwrap();
2895 rkit.db_engine = Some(
2896 super::init_db_engine(demo_engine(temp_dir.path().to_str().unwrap(), 5))
2897 .await
2898 .unwrap(),
2899 );
2900 super::ensure_db_engine_tables(rkit.db_engine.as_ref().unwrap())
2901 .await
2902 .unwrap();
2903 rkit.embeddings_provider = Some(InitializedEmbeddingsProvider::Mock(MockEmbedder::new(3)));
2904
2905 let error = rkit
2906 .ingest_documents(vec![
2907 "First document content.".to_string(),
2908 "Second document content.".to_string(),
2909 ])
2910 .await
2911 .unwrap_err();
2912
2913 assert!(matches!(error, IngestDocumentError::DbEngine(_)));
2914
2915 let backend = rkit.lancedb_backend().unwrap();
2916 let table = backend
2917 .connection()
2918 .open_table("chunks")
2919 .execute()
2920 .await
2921 .unwrap();
2922 assert_eq!(table.count_rows(None).await.unwrap(), 0);
2923 }
2924}