1use std::path::{Path, PathBuf};
4use std::sync::Arc;
5
6use async_trait::async_trait;
7use tokio::sync::RwLock as TokioRwLock;
8
9use cognee_database::{DatabaseConnection, connect, initialize};
10use cognee_embedding::{EmbeddingConfig, EmbeddingEngine, EmbeddingProvider};
11use cognee_graph::GraphDBTrait;
12#[cfg(feature = "ladybug")]
13use cognee_graph::LadybugAdapter;
14#[cfg(feature = "pggraph")]
15use cognee_graph::PgGraphAdapter;
16use cognee_llm::{Llm, OpenAIAdapter, Transcriber};
17use cognee_storage::{LocalStorage, StorageTrait};
18#[cfg(feature = "pgvector")]
19use cognee_vector::PgVectorAdapter;
20use cognee_vector::{BruteForceVectorDB, VectorDB};
21
22use crate::config::{ConfigManager, Settings};
23use crate::context::PipelineContext;
24use crate::error::ComponentError;
25
/// Assemble a `postgres://` connection URL from its individual parts.
///
/// A fixed, known-good base URL is parsed first and then each component is
/// overwritten through the `url` crate's setters, so no URL text is ever
/// concatenated by hand.
///
/// # Errors
///
/// Returns a human-readable message when the host, port, username or
/// password is rejected by the corresponding setter.
#[cfg(any(feature = "pgvector", feature = "pggraph"))]
fn build_postgres_url(
    host: &str,
    port: u16,
    name: &str,
    user: &str,
    pass: &str,
) -> Result<String, String> {
    #[allow(clippy::expect_used, reason = "invariant is upheld by construction")]
    let mut pg_url = url::Url::parse("postgres://localhost").expect("static URL is always valid");

    if let Err(e) = pg_url.set_host(Some(host)) {
        return Err(format!("invalid host '{host}': {e}"));
    }
    if pg_url.set_port(Some(port)).is_err() {
        return Err(format!("invalid port {port}"));
    }

    // The database name becomes the URL path.
    let db_path = format!("/{name}");
    pg_url.set_path(&db_path);

    if pg_url.set_username(user).is_err() {
        return Err(format!("invalid username '{user}'"));
    }
    if pg_url.set_password(Some(pass)).is_err() {
        return Err("invalid password".to_string());
    }

    Ok(pg_url.to_string())
}
53
54pub struct ComponentManager {
63 config: ConfigManager,
64 storage: TokioRwLock<Option<(u64, Arc<dyn StorageTrait>)>>,
68 database: TokioRwLock<Option<(u64, Arc<DatabaseConnection>)>>,
69 graph_db: TokioRwLock<Option<(u64, Arc<dyn GraphDBTrait>)>>,
70 vector_db: TokioRwLock<Option<(u64, Arc<dyn VectorDB>)>>,
71 embedding_engine: TokioRwLock<Option<(u64, Arc<dyn EmbeddingEngine>)>>,
72 llm: TokioRwLock<Option<(u64, Arc<dyn Llm>)>>,
73 #[allow(clippy::type_complexity)]
77 transcriber: TokioRwLock<Option<(u64, Option<Arc<dyn Transcriber>>)>>,
78}
79
80impl ComponentManager {
81 pub fn new(config: ConfigManager) -> Self {
82 Self {
83 config,
84 storage: TokioRwLock::new(None),
85 database: TokioRwLock::new(None),
86 graph_db: TokioRwLock::new(None),
87 vector_db: TokioRwLock::new(None),
88 embedding_engine: TokioRwLock::new(None),
89 llm: TokioRwLock::new(None),
90 transcriber: TokioRwLock::new(None),
91 }
92 }
93
94 pub fn settings(&self) -> std::sync::RwLockReadGuard<'_, Settings> {
99 self.config.read()
100 }
101
102 pub fn config(&self) -> &ConfigManager {
104 &self.config
105 }
106
107 async fn init_storage(&self) -> Result<Arc<dyn StorageTrait>, ComponentError> {
108 let data_root = self.config.read().data_root_directory.clone();
109 let storage = LocalStorage::new(PathBuf::from(&data_root));
110 storage
111 .initialize()
112 .await
113 .map_err(|e| ComponentError::Storage(format!("initialization failed: {e}")))?;
114 Ok(Arc::new(storage))
115 }
116
117 async fn init_database(&self) -> Result<Arc<DatabaseConnection>, ComponentError> {
118 let url = self.config.read().resolved_relational_db_url();
119
120 if url.starts_with("sqlite:") && !url.contains(":memory:") {
132 let after_scheme = url.trim_start_matches("sqlite:");
135 let path_part = if after_scheme.starts_with("//localhost/") {
136 Some(&after_scheme["//localhost".len()..])
137 } else if after_scheme.starts_with("///") {
138 Some(&after_scheme[2..])
140 } else if after_scheme.starts_with("//") {
141 None
144 } else {
145 Some(after_scheme)
146 };
147 if let Some(path_part) = path_part {
149 let path_no_query = path_part.split('?').next().unwrap_or(path_part);
150 let db_path = Path::new(path_no_query);
151 if let Some(parent) = db_path.parent()
152 && !parent.as_os_str().is_empty()
153 {
154 if let Err(e) = std::fs::create_dir_all(parent) {
157 tracing::warn!(
158 "could not create SQLite parent directory '{}': {e}",
159 parent.display()
160 );
161 }
162 }
163 }
164 }
165
166 let db = connect(&url)
167 .await
168 .map_err(|e| ComponentError::Database(format!("initialization failed: {e}")))?;
169 initialize(&db)
170 .await
171 .map_err(|e| ComponentError::Database(format!("schema initialization failed: {e}")))?;
172 Ok(Arc::new(db))
173 }
174
175 async fn init_graph_db(&self) -> Result<Arc<dyn GraphDBTrait>, ComponentError> {
176 let provider = self.config.read().graph_database_provider.to_lowercase();
177
178 match provider.as_str() {
179 "ladybug" | "kuzu" => self.init_ladybug_graph_db().await,
180
181 #[cfg(feature = "pggraph")]
182 "postgres" | "postgresql" => {
183 let url = {
184 let s = self.config.read();
185 self.resolved_graph_db_url(&s)?
186 };
187 let adapter = PgGraphAdapter::new(&url)
188 .await
189 .map_err(|e| ComponentError::GraphDb(format!("pggraph init failed: {e}")))?;
190 Ok(Arc::new(adapter))
191 }
192
193 #[cfg(not(feature = "pggraph"))]
194 "postgres" | "postgresql" => Err(ComponentError::Config(
195 "graph_database_provider=postgres requires the `pggraph` crate feature".into(),
196 )),
197
198 other => Err(ComponentError::Config(format!(
199 "Unsupported graph_database_provider '{other}'. Supported: ladybug, kuzu, postgres.",
200 ))),
201 }
202 }
203
204 async fn init_ladybug_graph_db(&self) -> Result<Arc<dyn GraphDBTrait>, ComponentError> {
205 let graph_path = {
206 let s = self.config.read();
207 if !s.graph_file_path.is_empty() {
208 s.graph_file_path.clone()
209 } else {
210 format!("{}/graph", s.system_root_directory)
211 }
212 };
213
214 if let Some(parent) = Path::new(&graph_path).parent() {
215 std::fs::create_dir_all(parent)?;
216 }
217
218 #[cfg(feature = "ladybug")]
219 {
220 let graph_db = LadybugAdapter::new(&graph_path)
221 .await
222 .map_err(|e| ComponentError::GraphDb(format!("initialization failed: {e}")))?;
223 graph_db.initialize().await.map_err(|e| {
224 ComponentError::GraphDb(format!("schema initialization failed: {e}"))
225 })?;
226 Ok(Arc::new(graph_db))
227 }
228
229 #[cfg(not(feature = "ladybug"))]
230 Err(ComponentError::Config(
231 "graph_database_provider=ladybug requires the `ladybug` crate feature".to_string(),
232 ))
233 }
234
235 #[cfg(feature = "pggraph")]
245 fn resolved_graph_db_url(&self, s: &Settings) -> Result<String, ComponentError> {
246 if s.graph_database_url.starts_with("postgres://")
247 || s.graph_database_url.starts_with("postgresql://")
248 {
249 return Ok(s.graph_database_url.clone());
250 }
251
252 let graph_host = if s.graph_database_host.is_empty() {
253 None
254 } else {
255 Some(s.graph_database_host.as_str())
256 };
257
258 let graph_creds_complete = graph_host.is_some()
259 && !s.graph_database_username.is_empty()
260 && !s.graph_database_name.is_empty();
261
262 let (host, port, name, user, pass) = if graph_creds_complete {
263 (
264 #[allow(clippy::expect_used, reason = "invariant is upheld by construction")]
265 graph_host.expect("checked above"),
266 s.graph_database_port,
267 s.graph_database_name.as_str(),
268 s.graph_database_username.as_str(),
269 s.graph_database_password.as_str(),
270 )
271 } else {
272 tracing::warn!(
273 "Postgres graph credentials not fully configured; falling back to the \
274 relational database configuration. Set GRAPH_DATABASE_* explicitly to avoid this."
275 );
276 if s.db_host.is_empty() || s.db_name.is_empty() || s.db_username.is_empty() {
277 return Err(ComponentError::Config(
278 "Missing required Postgres graph credentials".into(),
279 ));
280 }
281 (
282 s.db_host.as_str(),
283 s.db_port,
284 s.db_name.as_str(),
285 s.db_username.as_str(),
286 s.db_password.as_str(),
287 )
288 };
289
290 build_postgres_url(host, port, name, user, pass)
291 .map_err(|e| ComponentError::Config(format!("failed to build graph DB URL: {e}")))
292 }
293
294 async fn init_vector_db(&self) -> Result<Arc<dyn VectorDB>, ComponentError> {
295 let (provider, _dim) = {
297 let s = self.config.read();
298 (
299 s.vector_db_provider.to_lowercase(),
300 s.embedding_dimensions as usize,
301 )
302 };
303
304 match provider.as_str() {
305 "pgvector" => {
306 #[cfg(feature = "pgvector")]
307 {
308 let url = {
309 let s = self.config.read();
310 self.resolved_vector_db_url(&s)?
311 };
312 let adapter = PgVectorAdapter::new(&url, _dim).await.map_err(|e| {
313 ComponentError::VectorDb(format!("pgvector init failed: {e}"))
314 })?;
315 Ok(Arc::new(adapter))
316 }
317
318 #[cfg(not(feature = "pgvector"))]
319 Err(ComponentError::Config(
320 "vector_db_provider=pgvector requires the `pgvector` crate feature".to_string(),
321 ))
322 }
323 "brute-force" | "brute_force" | "bruteforce" => Ok(Arc::new(BruteForceVectorDB::new())),
325 "lancedb" => {
330 let url = {
331 let s = self.config.read();
332 s.vector_db_url.clone()
333 };
334 if url == ":memory:" {
335 return Ok(Arc::new(BruteForceVectorDB::new()));
336 }
337 #[cfg(not(target_os = "android"))]
338 {
339 let path = self.resolved_lancedb_path(&url);
340 let adapter = cognee_vector::LanceDbAdapter::new(path)
341 .await
342 .map_err(|e| {
343 ComponentError::VectorDb(format!("lancedb init failed: {e}"))
344 })?;
345 Ok(Arc::new(adapter))
346 }
347 #[cfg(target_os = "android")]
348 {
349 tracing::warn!(
350 "vector_db_provider='lancedb' is not available on Android; \
351 falling back to in-memory brute-force. Set \
352 vector_db_provider='pgvector' for production durable storage."
353 );
354 Ok(Arc::new(BruteForceVectorDB::new()))
355 }
356 }
357 "qdrant" => Err(ComponentError::Config(format!(
363 "vector_db_provider='{provider}' is not available in this build. \
364 The Qdrant adapter has been extracted to the closed \
365 cognee-vector-qdrant crate. Use vector_db_provider='pgvector', \
366 'lancedb', or 'brute-force' in OSS."
367 ))),
368 #[cfg(feature = "testing")]
369 "mock" => Ok(Arc::new(cognee_vector::MockVectorDB::new())),
370 other => Err(ComponentError::Config(format!(
371 "Unsupported vector_db_provider '{other}'. \
372 Supported: pgvector, lancedb (non-Android), brute-force, mock (testing feature only).",
373 ))),
374 }
375 }
376
377 #[cfg(not(target_os = "android"))]
384 fn resolved_lancedb_path(&self, vector_db_url: &str) -> std::path::PathBuf {
385 use std::path::PathBuf;
386 if !vector_db_url.is_empty() {
387 return PathBuf::from(vector_db_url);
388 }
389 let root = {
390 let s = self.config.read();
391 s.system_root_directory.clone()
392 };
393 PathBuf::from(root).join("databases").join("cognee.lancedb")
394 }
395
396 #[cfg(feature = "pgvector")]
403 fn resolved_vector_db_url(&self, settings: &Settings) -> Result<String, ComponentError> {
404 if settings.vector_db_url.starts_with("postgres://")
405 || settings.vector_db_url.starts_with("postgresql://")
406 {
407 return Ok(settings.vector_db_url.clone());
408 }
409
410 let host = if settings.vector_db_url.is_empty() {
411 "localhost"
412 } else {
413 &settings.vector_db_url
414 };
415 let port = settings.vector_db_port;
416 let name = if settings.vector_db_name.is_empty() {
417 "cognee_vectors"
418 } else {
419 &settings.vector_db_name
420 };
421 let user = if settings.db_username.is_empty() {
422 "postgres"
423 } else {
424 &settings.db_username
425 };
426 let pass = &settings.db_password;
427
428 build_postgres_url(host, port, name, user, pass)
429 .map_err(|e| ComponentError::Config(format!("failed to build vector DB URL: {e}")))
430 }
431
432 async fn init_embedding_engine(&self) -> Result<Arc<dyn EmbeddingEngine>, ComponentError> {
438 let mut config = {
441 let settings = self.config.read();
442
443 let provider_str = settings.embedding_provider.trim().to_lowercase();
445 let provider = match provider_str.as_str() {
446 "onnx" => EmbeddingProvider::Onnx,
447 "fastembed" => EmbeddingProvider::Fastembed,
448 "openai" => EmbeddingProvider::OpenAi,
449 "openai_compatible" => EmbeddingProvider::OpenAiCompatible,
450 "ollama" => EmbeddingProvider::Ollama,
451 "mock" => EmbeddingProvider::Mock,
452 _ => EmbeddingProvider::Onnx,
453 };
454
455 let endpoint = [&settings.embedding_endpoint, &settings.llm_endpoint]
462 .into_iter()
463 .find(|v| !v.is_empty())
464 .cloned();
465
466 let api_key = [&settings.embedding_api_key, &settings.llm_api_key]
467 .into_iter()
468 .find(|v| !v.is_empty())
469 .cloned();
470
471 let mock_mode = std::env::var("MOCK_EMBEDDING")
475 .ok()
476 .map(|v| v.trim().to_lowercase());
477 let mock_deterministic =
478 matches!(mock_mode.as_deref(), Some("deterministic") | Some("hash"));
479 let mock = mock_deterministic
480 || matches!(mock_mode.as_deref(), Some("true") | Some("1") | Some("yes"));
481 let mock_mode = if mock_deterministic {
482 cognee_embedding::MockVectorMode::Deterministic
483 } else {
484 cognee_embedding::MockVectorMode::Zero
485 };
486
487 EmbeddingConfig {
488 provider: if mock {
489 EmbeddingProvider::Mock
490 } else {
491 provider
492 },
493 model: settings.embedding_model_name.clone(),
494 dimensions: settings.embedding_dimensions as usize,
495 endpoint,
496 api_key,
497 api_version: None,
498 max_completion_tokens: 8191,
499 batch_size: settings.embedding_batch_size as usize,
500 mock,
501 mock_mode,
502 #[cfg(feature = "onnx")]
503 onnx: cognee_embedding::OnnxEmbeddingConfig {
504 model_path: PathBuf::from(&settings.embedding_model_path),
505 tokenizer_path: PathBuf::from(&settings.embedding_tokenizer_path),
506 model_name: settings.embedding_model_name.clone(),
507 dimensions: settings.embedding_dimensions as usize,
508 max_sequence_length: settings.embedding_max_sequence_length as usize,
509 batch_size: settings.embedding_batch_size as usize,
510 },
511 huggingface_tokenizer: None,
512 }
513 };
514 if let Ok(val) = std::env::var("EMBEDDING_API_VERSION") {
519 let val = val.trim().to_string();
520 if !val.is_empty() {
521 config.api_version = Some(val);
522 }
523 }
524 if let Ok(val) = std::env::var("HUGGINGFACE_TOKENIZER") {
525 let val = val.trim().to_string();
526 if !val.is_empty() {
527 config.huggingface_tokenizer = Some(val);
528 }
529 }
530 if let Ok(val) = std::env::var("EMBEDDING_MAX_COMPLETION_TOKENS")
531 && let Ok(n) = val.trim().parse::<usize>()
532 {
533 config.max_completion_tokens = n;
534 }
535
536 config.create_engine().await.map_err(|e| {
537 ComponentError::EmbeddingEngine(format!("embedding engine init failed: {e}"))
538 })
539 }
540
541 async fn init_llm(&self) -> Result<Arc<dyn Llm>, ComponentError> {
542 let (
544 provider,
545 llm_model,
546 llm_api_key,
547 llm_endpoint,
548 llm_max_retries,
549 llm_mock,
550 llm_cassette,
551 llm_record_path,
552 ) = {
553 let s = self.config.read();
554 (
555 s.llm_provider.to_lowercase(),
556 s.llm_model.clone(),
557 s.llm_api_key.clone(),
558 s.llm_endpoint.clone(),
559 s.llm_max_retries,
560 s.llm_mock,
561 s.llm_cassette.clone(),
562 s.llm_record_path.clone(),
563 )
564 };
565
566 #[cfg(not(feature = "mock-llm"))]
569 let _ = &llm_cassette;
570
571 if llm_mock || provider == "mock" {
574 #[cfg(feature = "mock-llm")]
575 {
576 let cassette = llm_cassette.trim();
577 if cassette.is_empty() {
578 return Err(ComponentError::Config(
579 "MOCK_LLM is set but MOCK_LLM_CASSETTE is empty; set it to a cassette path"
580 .to_string(),
581 ));
582 }
583 let replay = cognee_llm::mock::ReplayLlm::from_path(cassette)
584 .map_err(|e| ComponentError::Llm(format!("mock cassette load failed: {e}")))?;
585 return Ok(Arc::new(replay));
586 }
587 #[cfg(not(feature = "mock-llm"))]
588 {
589 return Err(ComponentError::Config(
590 "MOCK_LLM was requested but the mock LLM is unavailable; \
591 rebuild with the `mock-llm` feature"
592 .to_string(),
593 ));
594 }
595 }
596
597 let adapter: Arc<dyn Llm> = match provider.as_str() {
599 "openai" => {
600 if llm_api_key.is_empty() {
601 return Err(ComponentError::Config(
602 "llm_api_key must be configured".to_string(),
603 ));
604 }
605
606 let endpoint = if llm_endpoint.is_empty() {
607 None
608 } else {
609 Some(llm_endpoint)
610 };
611
612 let retries = llm_max_retries.max(1);
613
614 let adapter = OpenAIAdapter::new(llm_model, llm_api_key, endpoint)
615 .map_err(|e| ComponentError::Llm(format!("initialization failed: {e}")))?
616 .with_structured_output_retries(retries)
617 .with_network_retries(retries);
618
619 Arc::new(adapter)
620 }
621 "litert" => {
622 return Err(ComponentError::Config(
623 "llm_provider=litert is not available in this build. \
624 The LiteRT adapter has been extracted to the closed cognee-llm-litert crate."
625 .to_string(),
626 ));
627 }
628 _ => {
629 return Err(ComponentError::Config(format!(
630 "Unsupported llm_provider '{provider}'. Supported: openai, mock.",
631 )));
632 }
633 };
634
635 if !llm_record_path.trim().is_empty() {
638 #[cfg(feature = "mock-llm")]
639 {
640 let recorder = cognee_llm::mock::RecordingLlm::new(
641 adapter,
642 llm_record_path.trim().to_string(),
643 );
644 return Ok(Arc::new(recorder));
645 }
646 #[cfg(not(feature = "mock-llm"))]
647 {
648 return Err(ComponentError::Config(
649 "COGNEE_RECORD_LLM was set but LLM recording is unavailable; \
650 rebuild with the `mock-llm` feature"
651 .to_string(),
652 ));
653 }
654 }
655
656 Ok(adapter)
657 }
658
659 async fn init_transcriber(&self) -> Result<Option<Arc<dyn Transcriber>>, ComponentError> {
660 let (provider, llm_model, llm_api_key, llm_endpoint, llm_max_retries) = {
661 let s = self.config.read();
662 (
663 s.llm_provider.to_lowercase(),
664 s.llm_model.clone(),
665 s.llm_api_key.clone(),
666 s.llm_endpoint.clone(),
667 s.llm_max_retries,
668 )
669 };
670
671 match provider.as_str() {
672 "openai" => {
673 if llm_api_key.is_empty() {
674 return Err(ComponentError::Config(
675 "llm_api_key must be configured".to_string(),
676 ));
677 }
678
679 let endpoint = if llm_endpoint.is_empty() {
680 None
681 } else {
682 Some(llm_endpoint)
683 };
684
685 let retries = llm_max_retries.max(1);
686
687 let adapter = OpenAIAdapter::new(llm_model, llm_api_key, endpoint)
688 .map_err(|e| ComponentError::Llm(format!("initialization failed: {e}")))?
689 .with_structured_output_retries(retries)
690 .with_network_retries(retries);
691
692 Ok(Some(Arc::new(adapter) as Arc<dyn Transcriber>))
693 }
694 _ => Ok(None),
697 }
698 }
699
700 pub async fn transcriber(&self) -> Result<Option<Arc<dyn Transcriber>>, ComponentError> {
706 let current_ver = self.config.version();
707 {
709 let guard = self.transcriber.read().await;
710 if let Some((ver, ref opt)) = *guard
711 && ver == current_ver
712 {
713 return Ok(opt.clone());
714 }
715 }
716 let mut guard = self.transcriber.write().await;
718 if let Some((ver, ref opt)) = *guard
719 && ver == current_ver
720 {
721 return Ok(opt.clone());
722 }
723 let new = self.init_transcriber().await?;
724 *guard = Some((current_ver, new.clone()));
725 Ok(new)
726 }
727}
728
/// Body of a cached component accessor.
///
/// Expands to code that returns the component stored in `$self.$field` when
/// it was built for the current configuration version, and otherwise calls
/// `$self.$init_fn()`, stores the result with the current version, and
/// yields it. Must be used inside an `async` function returning
/// `Result<Arc<_>, _>`, because the expansion uses `.await`, `?` and
/// `return`.
macro_rules! versioned_accessor {
    ($self:ident, $field:ident, $init_fn:ident) => {{
        let wanted = $self.config.version();

        // First look under the shared lock, released at the end of the block.
        {
            let slot = $self.$field.read().await;
            match &*slot {
                Some((ver, cached)) if *ver == wanted => return Ok(Arc::clone(cached)),
                _ => {}
            }
        }

        // Look again under the exclusive lock: another task may have filled
        // the slot in the meantime.
        let mut slot = $self.$field.write().await;
        match &*slot {
            Some((ver, cached)) if *ver == wanted => return Ok(Arc::clone(cached)),
            _ => {}
        }

        let fresh = $self.$init_fn().await?;
        *slot = Some((wanted, Arc::clone(&fresh)));
        Ok(fresh)
    }};
}
756
757#[async_trait]
758impl PipelineContext for ComponentManager {
759 async fn storage(&self) -> Result<Arc<dyn StorageTrait>, ComponentError> {
760 versioned_accessor!(self, storage, init_storage)
761 }
762
763 async fn database(&self) -> Result<Arc<DatabaseConnection>, ComponentError> {
764 versioned_accessor!(self, database, init_database)
765 }
766
767 async fn graph_db(&self) -> Result<Arc<dyn GraphDBTrait>, ComponentError> {
768 versioned_accessor!(self, graph_db, init_graph_db)
769 }
770
771 async fn vector_db(&self) -> Result<Arc<dyn VectorDB>, ComponentError> {
772 versioned_accessor!(self, vector_db, init_vector_db)
773 }
774
775 async fn embedding_engine(&self) -> Result<Arc<dyn EmbeddingEngine>, ComponentError> {
776 versioned_accessor!(self, embedding_engine, init_embedding_engine)
777 }
778
779 async fn llm(&self) -> Result<Arc<dyn Llm>, ComponentError> {
780 versioned_accessor!(self, llm, init_llm)
781 }
782}
783
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    reason = "test code — panics are acceptable failures"
)]
mod tests {
    use super::*;
    use crate::config::{ConfigManager, Settings};

    /// Wrap `settings` in a fresh component manager.
    fn manager_for(settings: Settings) -> ComponentManager {
        ComponentManager::new(ConfigManager::new(settings))
    }

    /// Manager configured for `provider` with a dummy API key and model.
    fn cm_with_provider(provider: &str) -> ComponentManager {
        manager_for(Settings {
            llm_provider: provider.to_string(),
            llm_api_key: "sk-test".to_string(),
            llm_model: "gpt-4o-mini".to_string(),
            ..Settings::default()
        })
    }

    #[tokio::test]
    async fn transcriber_returns_some_for_openai() {
        let transcriber = cm_with_provider("openai")
            .transcriber()
            .await
            .expect("transcriber() should not error");
        assert!(
            transcriber.is_some(),
            "openai provider must yield Some(transcriber)"
        );
    }

    #[tokio::test]
    async fn transcriber_returns_none_for_unknown_provider() {
        let cm = manager_for(Settings {
            llm_provider: "mock".to_string(),
            llm_api_key: String::new(),
            ..Settings::default()
        });
        let transcriber = cm
            .transcriber()
            .await
            .expect("transcriber() should not error for mock");
        assert!(transcriber.is_none(), "non-openai provider must yield None");
    }

    #[tokio::test]
    async fn transcriber_is_cached_across_calls() {
        let cm = cm_with_provider("openai");
        let first = cm.transcriber().await.expect("first call").unwrap();
        let second = cm.transcriber().await.expect("second call").unwrap();
        // Pointer equality proves the second call reused the cached instance.
        assert!(Arc::ptr_eq(&first, &second), "transcriber should be cached");
    }

    /// Resolve the graph DB URL for `settings` through a manager built from
    /// the same settings.
    #[cfg(feature = "pggraph")]
    fn resolve_graph_url(settings: Settings) -> Result<String, ComponentError> {
        manager_for(settings.clone()).resolved_graph_db_url(&settings)
    }

    #[cfg(feature = "pggraph")]
    #[test]
    fn resolved_graph_db_url_returns_explicit_url_as_is() {
        let url = resolve_graph_url(Settings {
            graph_database_url: "postgres://user:pw@myhost:5432/graphs".to_string(),
            ..Settings::default()
        })
        .expect("should succeed with full URL");
        assert_eq!(url, "postgres://user:pw@myhost:5432/graphs");
    }

    #[cfg(feature = "pggraph")]
    #[test]
    fn resolved_graph_db_url_builds_from_graph_creds() {
        let url = resolve_graph_url(Settings {
            graph_database_host: "graphhost".to_string(),
            graph_database_port: 5432,
            graph_database_name: "mygraph".to_string(),
            graph_database_username: "guser".to_string(),
            graph_database_password: "gpass".to_string(),
            ..Settings::default()
        })
        .expect("should build from graph creds");

        for (needle, why) in [
            ("guser", "URL should contain username"),
            ("graphhost", "URL should contain host"),
            ("mygraph", "URL should contain db name"),
        ] {
            assert!(url.contains(needle), "{why}");
        }
    }

    #[cfg(feature = "pggraph")]
    #[test]
    fn resolved_graph_db_url_falls_back_to_relational_creds() {
        let url = resolve_graph_url(Settings {
            db_host: "relhost".to_string(),
            db_port: 5432,
            db_name: "reldb".to_string(),
            db_username: "reluser".to_string(),
            db_password: "relpass".to_string(),
            ..Settings::default()
        })
        .expect("should fall back to relational creds");

        for (needle, why) in [
            ("reluser", "URL should contain relational username"),
            ("relhost", "URL should contain relational host"),
            ("reldb", "URL should contain relational db name"),
        ] {
            assert!(url.contains(needle), "{why}");
        }
    }

    #[cfg(feature = "pggraph")]
    #[test]
    fn resolved_graph_db_url_errors_when_no_creds() {
        let outcome = resolve_graph_url(Settings {
            db_host: String::new(),
            db_name: String::new(),
            db_username: String::new(),
            ..Settings::default()
        });
        assert!(outcome.is_err(), "should error when no creds available");
    }

    #[tokio::test]
    async fn init_graph_db_rejects_unsupported_provider() {
        let cm = manager_for(Settings {
            graph_database_provider: "neo4j".to_string(),
            ..Settings::default()
        });
        // `.err()` avoids requiring `Debug` on the success type.
        let err_msg = cm
            .graph_db()
            .await
            .err()
            .expect("expected error")
            .to_string();
        assert!(
            err_msg.contains("postgres"),
            "error message should list 'postgres' as supported: {err_msg}"
        );
    }

    /// Write a minimal, empty cassette file and return it together with the
    /// temp dir that keeps it alive.
    #[cfg(feature = "mock-llm")]
    fn write_cassette() -> (tempfile::TempDir, std::path::PathBuf) {
        let dir = tempfile::tempdir().expect("tempdir");
        let path = dir.path().join("cassette.json");
        std::fs::write(&path, r#"{"version":1,"model":"mock-model","entries":{}}"#)
            .expect("write cassette");
        (dir, path)
    }

    #[cfg(feature = "mock-llm")]
    #[tokio::test]
    async fn init_llm_uses_replay_mock_when_llm_mock_set_without_api_key() {
        let (_dir, cassette) = write_cassette();
        let cm = manager_for(Settings {
            llm_mock: true,
            llm_cassette: cassette.to_string_lossy().into_owned(),
            llm_api_key: String::new(),
            ..Settings::default()
        });
        let llm = cm.llm().await.expect("mock llm should initialize offline");
        assert_eq!(
            llm.model(),
            "mock-model",
            "replay mock reports cassette model"
        );

        let prompt = vec![cognee_llm::Message {
            role: cognee_llm::MessageRole::User,
            content: "hello".to_string(),
        }];
        let resp = llm
            .generate(prompt, None)
            .await
            .expect("offline generate should succeed");
        assert_eq!(resp.model, "mock-model");
    }

    #[cfg(feature = "mock-llm")]
    #[tokio::test]
    async fn init_llm_selects_mock_when_provider_is_mock() {
        let (_dir, cassette) = write_cassette();
        let cm = manager_for(Settings {
            llm_provider: "mock".to_string(),
            llm_cassette: cassette.to_string_lossy().into_owned(),
            llm_api_key: String::new(),
            ..Settings::default()
        });
        let llm = cm
            .llm()
            .await
            .expect("provider=mock should initialize offline");
        assert_eq!(llm.model(), "mock-model");
    }

    #[cfg(feature = "mock-llm")]
    #[tokio::test]
    async fn init_llm_errors_when_mock_set_but_cassette_empty() {
        let cm = manager_for(Settings {
            llm_mock: true,
            llm_cassette: String::new(),
            ..Settings::default()
        });
        let err = cm.llm().await.err().expect("empty cassette must error");
        assert!(
            err.to_string().contains("MOCK_LLM_CASSETTE"),
            "error should mention the missing cassette env: {err}"
        );
    }

    #[cfg(feature = "mock-llm")]
    #[tokio::test]
    async fn init_llm_wraps_real_adapter_in_recorder_when_record_path_set() {
        let dir = tempfile::tempdir().expect("tempdir");
        let record_path = dir.path().join("recorded.json");
        let cm = manager_for(Settings {
            llm_provider: "openai".to_string(),
            llm_api_key: "sk-test".to_string(),
            llm_model: "gpt-4o-mini".to_string(),
            llm_record_path: record_path.to_string_lossy().into_owned(),
            ..Settings::default()
        });
        let llm = cm
            .llm()
            .await
            .expect("recording wrap should initialize without network");
        assert_eq!(
            llm.model(),
            "gpt-4o-mini",
            "recorder delegates model() to the wrapped adapter"
        );
    }
}