Skip to main content

cognee_lib/
component_manager.rs

1//! ComponentManager: lazy-initializing, shared component store.
2
3use std::path::{Path, PathBuf};
4use std::sync::Arc;
5
6use async_trait::async_trait;
7use tokio::sync::RwLock as TokioRwLock;
8
9use cognee_database::{DatabaseConnection, connect, initialize};
10use cognee_embedding::{EmbeddingConfig, EmbeddingEngine, EmbeddingProvider};
11use cognee_graph::GraphDBTrait;
12#[cfg(feature = "ladybug")]
13use cognee_graph::LadybugAdapter;
14#[cfg(feature = "pggraph")]
15use cognee_graph::PgGraphAdapter;
16use cognee_llm::{Llm, OpenAIAdapter, Transcriber};
17use cognee_storage::{LocalStorage, StorageTrait};
18#[cfg(feature = "pgvector")]
19use cognee_vector::PgVectorAdapter;
20use cognee_vector::{BruteForceVectorDB, VectorDB};
21
22use crate::config::{ConfigManager, Settings};
23use crate::context::PipelineContext;
24use crate::error::ComponentError;
25
/// Build a `postgres://user:pass@host:port/dbname` connection string.
///
/// Each component is applied through the `url` crate's setters on a
/// placeholder URL instead of being spliced together with `format!`, which is
/// what gets the credentials percent-encoded. Both the vector and the graph
/// URL resolvers go through this function.
///
/// # Errors
///
/// Returns a plain message naming the component that was rejected (host,
/// port, username or password). The password value is never included in the
/// message.
#[cfg(any(feature = "pgvector", feature = "pggraph"))]
fn build_postgres_url(
    host: &str,
    port: u16,
    name: &str,
    user: &str,
    pass: &str,
) -> Result<String, String> {
    #[allow(clippy::expect_used, reason = "invariant is upheld by construction")]
    let mut pg_url = url::Url::parse("postgres://localhost").expect("static URL is always valid");

    if let Err(e) = pg_url.set_host(Some(host)) {
        return Err(format!("invalid host '{host}': {e}"));
    }
    if pg_url.set_port(Some(port)).is_err() {
        return Err(format!("invalid port {port}"));
    }

    // The database name is the single path segment of the URL.
    let db_path = format!("/{name}");
    pg_url.set_path(&db_path);

    if pg_url.set_username(user).is_err() {
        return Err(format!("invalid username '{user}'"));
    }
    if pg_url.set_password(Some(pass)).is_err() {
        return Err("invalid password".to_string());
    }

    Ok(pg_url.to_string())
}
53
54/// Manages shared, lazily-initialized pipeline components.
55///
56/// Each component is created on first access and cached for subsequent calls.
57/// When the underlying [`ConfigManager`]'s version advances (due to a setter
58/// call), cached components are lazily re-created on the next access.
59///
60/// Constructed from [`ConfigManager`] — typically loaded once via
61/// `ConfigManager::from_env()` or `ConfigManager::new(settings)`.
62pub struct ComponentManager {
63    config: ConfigManager,
64    // Each cached component stores (version_at_creation, component_arc).
65    // When the config version advances past the cached version, the
66    // component is lazily re-created on next access.
67    storage: TokioRwLock<Option<(u64, Arc<dyn StorageTrait>)>>,
68    database: TokioRwLock<Option<(u64, Arc<DatabaseConnection>)>>,
69    graph_db: TokioRwLock<Option<(u64, Arc<dyn GraphDBTrait>)>>,
70    vector_db: TokioRwLock<Option<(u64, Arc<dyn VectorDB>)>>,
71    embedding_engine: TokioRwLock<Option<(u64, Arc<dyn EmbeddingEngine>)>>,
72    llm: TokioRwLock<Option<(u64, Arc<dyn Llm>)>>,
73    // Stores Option<Arc<dyn Transcriber>>: None when the provider does not
74    // support transcription (e.g. litert). The outer Option<(ver, ...)> is
75    // the version-keyed cache envelope.
76    #[allow(clippy::type_complexity)]
77    transcriber: TokioRwLock<Option<(u64, Option<Arc<dyn Transcriber>>)>>,
78}
79
80impl ComponentManager {
81    pub fn new(config: ConfigManager) -> Self {
82        Self {
83            config,
84            storage: TokioRwLock::new(None),
85            database: TokioRwLock::new(None),
86            graph_db: TokioRwLock::new(None),
87            vector_db: TokioRwLock::new(None),
88            embedding_engine: TokioRwLock::new(None),
89            llm: TokioRwLock::new(None),
90            transcriber: TokioRwLock::new(None),
91        }
92    }
93
94    /// Read-only snapshot of current settings.
95    ///
96    /// Returns a `RwLockReadGuard` that auto-derefs to `&Settings`.
97    /// Most call sites that use `cm.settings().field_name` work unchanged.
98    pub fn settings(&self) -> std::sync::RwLockReadGuard<'_, Settings> {
99        self.config.read()
100    }
101
102    /// Access the underlying [`ConfigManager`] for runtime mutation.
103    pub fn config(&self) -> &ConfigManager {
104        &self.config
105    }
106
107    async fn init_storage(&self) -> Result<Arc<dyn StorageTrait>, ComponentError> {
108        let data_root = self.config.read().data_root_directory.clone();
109        let storage = LocalStorage::new(PathBuf::from(&data_root));
110        storage
111            .initialize()
112            .await
113            .map_err(|e| ComponentError::Storage(format!("initialization failed: {e}")))?;
114        Ok(Arc::new(storage))
115    }
116
117    async fn init_database(&self) -> Result<Arc<DatabaseConnection>, ComponentError> {
118        let url = self.config.read().resolved_relational_db_url();
119
120        // For SQLite file-backed databases, ensure the parent directory exists
121        // before handing the URL to sea-orm.  sea-orm's `?mode=rwc` creates the
122        // *file* but not missing ancestor directories, so without this step any
123        // settings override that redirects the DB to a new path (e.g. per-test
124        // isolation) would fail with "unable to open database file".
125        //
126        // URL shapes we handle:
127        //   sqlite:./rel/path/db       (relative, 1-slash)
128        //   sqlite:///abs/path/db      (absolute, 3-slash)
129        //   sqlite://localhost/abs/db  (host form)
130        // All others (postgres, in-memory `sqlite::memory:`) are left alone.
131        if url.starts_with("sqlite:") && !url.contains(":memory:") {
132            // Strip the sqlite: scheme and any leading host ("//localhost") or
133            // extra slashes to get the raw filesystem path (before '?').
134            let after_scheme = url.trim_start_matches("sqlite:");
135            let path_part = if after_scheme.starts_with("//localhost/") {
136                Some(&after_scheme["//localhost".len()..])
137            } else if after_scheme.starts_with("///") {
138                // sqlite:///abs/path — empty authority, absolute path.
139                Some(&after_scheme[2..])
140            } else if after_scheme.starts_with("//") {
141                // sqlite://somehost/... — genuine host form; leave entirely to
142                // the driver instead of attempting create_dir_all("//somehost").
143                None
144            } else {
145                Some(after_scheme)
146            };
147            // Drop query string (e.g. ?mode=rwc).
148            if let Some(path_part) = path_part {
149                let path_no_query = path_part.split('?').next().unwrap_or(path_part);
150                let db_path = Path::new(path_no_query);
151                if let Some(parent) = db_path.parent()
152                    && !parent.as_os_str().is_empty()
153                {
154                    // Non-fatal: an unusual-but-driver-valid URL must still
155                    // reach sea-orm and surface the driver's own error.
156                    if let Err(e) = std::fs::create_dir_all(parent) {
157                        tracing::warn!(
158                            "could not create SQLite parent directory '{}': {e}",
159                            parent.display()
160                        );
161                    }
162                }
163            }
164        }
165
166        let db = connect(&url)
167            .await
168            .map_err(|e| ComponentError::Database(format!("initialization failed: {e}")))?;
169        initialize(&db)
170            .await
171            .map_err(|e| ComponentError::Database(format!("schema initialization failed: {e}")))?;
172        Ok(Arc::new(db))
173    }
174
175    async fn init_graph_db(&self) -> Result<Arc<dyn GraphDBTrait>, ComponentError> {
176        let provider = self.config.read().graph_database_provider.to_lowercase();
177
178        match provider.as_str() {
179            "ladybug" | "kuzu" => self.init_ladybug_graph_db().await,
180
181            #[cfg(feature = "pggraph")]
182            "postgres" | "postgresql" => {
183                let url = {
184                    let s = self.config.read();
185                    self.resolved_graph_db_url(&s)?
186                };
187                let adapter = PgGraphAdapter::new(&url)
188                    .await
189                    .map_err(|e| ComponentError::GraphDb(format!("pggraph init failed: {e}")))?;
190                Ok(Arc::new(adapter))
191            }
192
193            #[cfg(not(feature = "pggraph"))]
194            "postgres" | "postgresql" => Err(ComponentError::Config(
195                "graph_database_provider=postgres requires the `pggraph` crate feature".into(),
196            )),
197
198            other => Err(ComponentError::Config(format!(
199                "Unsupported graph_database_provider '{other}'. Supported: ladybug, kuzu, postgres.",
200            ))),
201        }
202    }
203
204    async fn init_ladybug_graph_db(&self) -> Result<Arc<dyn GraphDBTrait>, ComponentError> {
205        let graph_path = {
206            let s = self.config.read();
207            if !s.graph_file_path.is_empty() {
208                s.graph_file_path.clone()
209            } else {
210                format!("{}/graph", s.system_root_directory)
211            }
212        };
213
214        if let Some(parent) = Path::new(&graph_path).parent() {
215            std::fs::create_dir_all(parent)?;
216        }
217
218        #[cfg(feature = "ladybug")]
219        {
220            let graph_db = LadybugAdapter::new(&graph_path)
221                .await
222                .map_err(|e| ComponentError::GraphDb(format!("initialization failed: {e}")))?;
223            graph_db.initialize().await.map_err(|e| {
224                ComponentError::GraphDb(format!("schema initialization failed: {e}"))
225            })?;
226            Ok(Arc::new(graph_db))
227        }
228
229        #[cfg(not(feature = "ladybug"))]
230        Err(ComponentError::Config(
231            "graph_database_provider=ladybug requires the `ladybug` crate feature".to_string(),
232        ))
233    }
234
235    /// Build a Postgres connection URL from the graph_database_* settings,
236    /// falling back to the relational db_* fields when graph-specific creds
237    /// are not fully configured (Python `get_graph_engine.py:332-367` parity).
238    ///
239    /// Precedence:
240    /// 1. `graph_database_url` already looks like a full `postgres://` URL → return as-is.
241    /// 2. All graph-specific fields are set (username, password, host, port, name) → build from those.
242    /// 3. Fall back to the relational `db_*` fields with a warning.
243    /// 4. Neither complete → error.
244    #[cfg(feature = "pggraph")]
245    fn resolved_graph_db_url(&self, s: &Settings) -> Result<String, ComponentError> {
246        if s.graph_database_url.starts_with("postgres://")
247            || s.graph_database_url.starts_with("postgresql://")
248        {
249            return Ok(s.graph_database_url.clone());
250        }
251
252        let graph_host = if s.graph_database_host.is_empty() {
253            None
254        } else {
255            Some(s.graph_database_host.as_str())
256        };
257
258        let graph_creds_complete = graph_host.is_some()
259            && !s.graph_database_username.is_empty()
260            && !s.graph_database_name.is_empty();
261
262        let (host, port, name, user, pass) = if graph_creds_complete {
263            (
264                #[allow(clippy::expect_used, reason = "invariant is upheld by construction")]
265                graph_host.expect("checked above"),
266                s.graph_database_port,
267                s.graph_database_name.as_str(),
268                s.graph_database_username.as_str(),
269                s.graph_database_password.as_str(),
270            )
271        } else {
272            tracing::warn!(
273                "Postgres graph credentials not fully configured; falling back to the \
274                 relational database configuration. Set GRAPH_DATABASE_* explicitly to avoid this."
275            );
276            if s.db_host.is_empty() || s.db_name.is_empty() || s.db_username.is_empty() {
277                return Err(ComponentError::Config(
278                    "Missing required Postgres graph credentials".into(),
279                ));
280            }
281            (
282                s.db_host.as_str(),
283                s.db_port,
284                s.db_name.as_str(),
285                s.db_username.as_str(),
286                s.db_password.as_str(),
287            )
288        };
289
290        build_postgres_url(host, port, name, user, pass)
291            .map_err(|e| ComponentError::Config(format!("failed to build graph DB URL: {e}")))
292    }
293
294    async fn init_vector_db(&self) -> Result<Arc<dyn VectorDB>, ComponentError> {
295        // Clone all needed fields out of the read guard before any await.
296        let (provider, _dim) = {
297            let s = self.config.read();
298            (
299                s.vector_db_provider.to_lowercase(),
300                s.embedding_dimensions as usize,
301            )
302        };
303
304        match provider.as_str() {
305            "pgvector" => {
306                #[cfg(feature = "pgvector")]
307                {
308                    let url = {
309                        let s = self.config.read();
310                        self.resolved_vector_db_url(&s)?
311                    };
312                    let adapter = PgVectorAdapter::new(&url, _dim).await.map_err(|e| {
313                        ComponentError::VectorDb(format!("pgvector init failed: {e}"))
314                    })?;
315                    Ok(Arc::new(adapter))
316                }
317
318                #[cfg(not(feature = "pgvector"))]
319                Err(ComponentError::Config(
320                    "vector_db_provider=pgvector requires the `pgvector` crate feature".to_string(),
321                ))
322            }
323            // Pure-Rust in-memory brute-force backend (Android default + ":memory:" escape hatch).
324            "brute-force" | "brute_force" | "bruteforce" => Ok(Arc::new(BruteForceVectorDB::new())),
325            // Embedded LanceDB on non-Android targets; falls back to brute-force
326            // on Android (LanceDB + Arrow native stack does not cross-compile
327            // there). Honours `vector_db_url = ":memory:"` as an explicit
328            // brute-force opt-in for ephemeral / test workloads.
329            "lancedb" => {
330                let url = {
331                    let s = self.config.read();
332                    s.vector_db_url.clone()
333                };
334                if url == ":memory:" {
335                    return Ok(Arc::new(BruteForceVectorDB::new()));
336                }
337                #[cfg(not(target_os = "android"))]
338                {
339                    let path = self.resolved_lancedb_path(&url);
340                    let adapter = cognee_vector::LanceDbAdapter::new(path)
341                        .await
342                        .map_err(|e| {
343                            ComponentError::VectorDb(format!("lancedb init failed: {e}"))
344                        })?;
345                    Ok(Arc::new(adapter))
346                }
347                #[cfg(target_os = "android")]
348                {
349                    tracing::warn!(
350                        "vector_db_provider='lancedb' is not available on Android; \
351                         falling back to in-memory brute-force. Set \
352                         vector_db_provider='pgvector' for production durable storage."
353                    );
354                    Ok(Arc::new(BruteForceVectorDB::new()))
355                }
356            }
357            // the closed split extracted the Qdrant adapter out of OSS into the closed
358            // `cognee-vector-qdrant` crate. OSS hard-errors rather than silently
359            // substituting a different backend, mirroring the closed-adapter
360            // handling elsewhere in this file (e.g. LiteRT). LanceDB stays OSS
361            // (handled above); only Qdrant is closed.
362            "qdrant" => Err(ComponentError::Config(format!(
363                "vector_db_provider='{provider}' is not available in this build. \
364                 The Qdrant adapter has been extracted to the closed \
365                 cognee-vector-qdrant crate. Use vector_db_provider='pgvector', \
366                 'lancedb', or 'brute-force' in OSS."
367            ))),
368            #[cfg(feature = "testing")]
369            "mock" => Ok(Arc::new(cognee_vector::MockVectorDB::new())),
370            other => Err(ComponentError::Config(format!(
371                "Unsupported vector_db_provider '{other}'. \
372                 Supported: pgvector, lancedb (non-Android), brute-force, mock (testing feature only).",
373            ))),
374        }
375    }
376
377    /// Resolve the on-disk path for the LanceDB store.
378    ///
379    /// Honours an explicit `vector_db_url` when set; otherwise defaults to
380    /// `{system_root_directory}/databases/cognee.lancedb` — matching the
381    /// Python SDK file layout so a Rust deployment can be opened from
382    /// Python and vice versa.
383    #[cfg(not(target_os = "android"))]
384    fn resolved_lancedb_path(&self, vector_db_url: &str) -> std::path::PathBuf {
385        use std::path::PathBuf;
386        if !vector_db_url.is_empty() {
387            return PathBuf::from(vector_db_url);
388        }
389        let root = {
390            let s = self.config.read();
391            s.system_root_directory.clone()
392        };
393        PathBuf::from(root).join("databases").join("cognee.lancedb")
394    }
395
396    /// Build a Postgres connection URL from the vector_db_* settings.
397    ///
398    /// If `vector_db_url` already looks like a full `postgres://` URL it is
399    /// returned as-is. Otherwise the URL is assembled from the individual
400    /// `vector_db_*` / `db_*` fields using the `url` crate so that special
401    /// characters in passwords are percent-encoded correctly.
402    #[cfg(feature = "pgvector")]
403    fn resolved_vector_db_url(&self, settings: &Settings) -> Result<String, ComponentError> {
404        if settings.vector_db_url.starts_with("postgres://")
405            || settings.vector_db_url.starts_with("postgresql://")
406        {
407            return Ok(settings.vector_db_url.clone());
408        }
409
410        let host = if settings.vector_db_url.is_empty() {
411            "localhost"
412        } else {
413            &settings.vector_db_url
414        };
415        let port = settings.vector_db_port;
416        let name = if settings.vector_db_name.is_empty() {
417            "cognee_vectors"
418        } else {
419            &settings.vector_db_name
420        };
421        let user = if settings.db_username.is_empty() {
422            "postgres"
423        } else {
424            &settings.db_username
425        };
426        let pass = &settings.db_password;
427
428        build_postgres_url(host, port, name, user, pass)
429            .map_err(|e| ComponentError::Config(format!("failed to build vector DB URL: {e}")))
430    }
431
432    /// Initialize the embedding engine from Settings fields instead of
433    /// calling `EmbeddingConfig::from_env()` directly.
434    ///
435    /// This ensures that runtime config changes via `ConfigManager` flow
436    /// through to the embedding engine.
437    async fn init_embedding_engine(&self) -> Result<Arc<dyn EmbeddingEngine>, ComponentError> {
438        // Build EmbeddingConfig from Settings inside a block so the guard
439        // is dropped before any .await point.
440        let mut config = {
441            let settings = self.config.read();
442
443            // Map Settings.embedding_provider string to EmbeddingProvider enum.
444            let provider_str = settings.embedding_provider.trim().to_lowercase();
445            let provider = match provider_str.as_str() {
446                "onnx" => EmbeddingProvider::Onnx,
447                "fastembed" => EmbeddingProvider::Fastembed,
448                "openai" => EmbeddingProvider::OpenAi,
449                "openai_compatible" => EmbeddingProvider::OpenAiCompatible,
450                "ollama" => EmbeddingProvider::Ollama,
451                "mock" => EmbeddingProvider::Mock,
452                _ => EmbeddingProvider::Onnx,
453            };
454
455            // Endpoint/key fall back to the LLM provider's when no embedding-
456            // specific values are set. The default embedding provider is OpenAI
457            // (off-Android), and cognee typically shares one OpenAI-compatible
458            // account for chat + embeddings, so this makes the default work with
459            // just OPENAI_URL/OPENAI_TOKEN (→ llm_endpoint/llm_api_key) — no
460            // separate EMBEDDING_ENDPOINT/EMBEDDING_API_KEY required.
461            let endpoint = [&settings.embedding_endpoint, &settings.llm_endpoint]
462                .into_iter()
463                .find(|v| !v.is_empty())
464                .cloned();
465
466            let api_key = [&settings.embedding_api_key, &settings.llm_api_key]
467                .into_iter()
468                .find(|v| !v.is_empty())
469                .cloned();
470
471            // Check MOCK_EMBEDDING env var as a fallback (preserves backward compat).
472            // `deterministic`/`hash` selects SHA-256-derived vectors; other truthy
473            // values keep the legacy zero-vector mode.
474            let mock_mode = std::env::var("MOCK_EMBEDDING")
475                .ok()
476                .map(|v| v.trim().to_lowercase());
477            let mock_deterministic =
478                matches!(mock_mode.as_deref(), Some("deterministic") | Some("hash"));
479            let mock = mock_deterministic
480                || matches!(mock_mode.as_deref(), Some("true") | Some("1") | Some("yes"));
481            let mock_mode = if mock_deterministic {
482                cognee_embedding::MockVectorMode::Deterministic
483            } else {
484                cognee_embedding::MockVectorMode::Zero
485            };
486
487            EmbeddingConfig {
488                provider: if mock {
489                    EmbeddingProvider::Mock
490                } else {
491                    provider
492                },
493                model: settings.embedding_model_name.clone(),
494                dimensions: settings.embedding_dimensions as usize,
495                endpoint,
496                api_key,
497                api_version: None,
498                max_completion_tokens: 8191,
499                batch_size: settings.embedding_batch_size as usize,
500                mock,
501                mock_mode,
502                #[cfg(feature = "onnx")]
503                onnx: cognee_embedding::OnnxEmbeddingConfig {
504                    model_path: PathBuf::from(&settings.embedding_model_path),
505                    tokenizer_path: PathBuf::from(&settings.embedding_tokenizer_path),
506                    model_name: settings.embedding_model_name.clone(),
507                    dimensions: settings.embedding_dimensions as usize,
508                    max_sequence_length: settings.embedding_max_sequence_length as usize,
509                    batch_size: settings.embedding_batch_size as usize,
510                },
511                huggingface_tokenizer: None,
512            }
513        };
514        // settings guard is now dropped — safe to await.
515
516        // Still check env vars for fields not yet in Settings (api_version,
517        // huggingface_tokenizer, max_completion_tokens) — forward compatibility.
518        if let Ok(val) = std::env::var("EMBEDDING_API_VERSION") {
519            let val = val.trim().to_string();
520            if !val.is_empty() {
521                config.api_version = Some(val);
522            }
523        }
524        if let Ok(val) = std::env::var("HUGGINGFACE_TOKENIZER") {
525            let val = val.trim().to_string();
526            if !val.is_empty() {
527                config.huggingface_tokenizer = Some(val);
528            }
529        }
530        if let Ok(val) = std::env::var("EMBEDDING_MAX_COMPLETION_TOKENS")
531            && let Ok(n) = val.trim().parse::<usize>()
532        {
533            config.max_completion_tokens = n;
534        }
535
536        config.create_engine().await.map_err(|e| {
537            ComponentError::EmbeddingEngine(format!("embedding engine init failed: {e}"))
538        })
539    }
540
541    async fn init_llm(&self) -> Result<Arc<dyn Llm>, ComponentError> {
542        // Clone all needed fields out of the read guard before any await.
543        let (
544            provider,
545            llm_model,
546            llm_api_key,
547            llm_endpoint,
548            llm_max_retries,
549            llm_mock,
550            llm_cassette,
551            llm_record_path,
552        ) = {
553            let s = self.config.read();
554            (
555                s.llm_provider.to_lowercase(),
556                s.llm_model.clone(),
557                s.llm_api_key.clone(),
558                s.llm_endpoint.clone(),
559                s.llm_max_retries,
560                s.llm_mock,
561                s.llm_cassette.clone(),
562                s.llm_record_path.clone(),
563            )
564        };
565
566        // `llm_cassette` is only consumed on the mock path; silence the
567        // unused-variable lint in builds without the `mock` feature.
568        #[cfg(not(feature = "mock-llm"))]
569        let _ = &llm_cassette;
570
571        // Mock first — like MOCK_EMBEDDING, this overrides the configured
572        // provider. Selected by `MOCK_LLM` (llm_mock) or `llm_provider=mock`.
573        if llm_mock || provider == "mock" {
574            #[cfg(feature = "mock-llm")]
575            {
576                let cassette = llm_cassette.trim();
577                if cassette.is_empty() {
578                    return Err(ComponentError::Config(
579                        "MOCK_LLM is set but MOCK_LLM_CASSETTE is empty; set it to a cassette path"
580                            .to_string(),
581                    ));
582                }
583                let replay = cognee_llm::mock::ReplayLlm::from_path(cassette)
584                    .map_err(|e| ComponentError::Llm(format!("mock cassette load failed: {e}")))?;
585                return Ok(Arc::new(replay));
586            }
587            #[cfg(not(feature = "mock-llm"))]
588            {
589                return Err(ComponentError::Config(
590                    "MOCK_LLM was requested but the mock LLM is unavailable; \
591                     rebuild with the `mock-llm` feature"
592                        .to_string(),
593                ));
594            }
595        }
596
597        // Build the real adapter exactly as before.
598        let adapter: Arc<dyn Llm> = match provider.as_str() {
599            "openai" => {
600                if llm_api_key.is_empty() {
601                    return Err(ComponentError::Config(
602                        "llm_api_key must be configured".to_string(),
603                    ));
604                }
605
606                let endpoint = if llm_endpoint.is_empty() {
607                    None
608                } else {
609                    Some(llm_endpoint)
610                };
611
612                let retries = llm_max_retries.max(1);
613
614                let adapter = OpenAIAdapter::new(llm_model, llm_api_key, endpoint)
615                    .map_err(|e| ComponentError::Llm(format!("initialization failed: {e}")))?
616                    .with_structured_output_retries(retries)
617                    .with_network_retries(retries);
618
619                Arc::new(adapter)
620            }
621            "litert" => {
622                return Err(ComponentError::Config(
623                    "llm_provider=litert is not available in this build. \
624                     The LiteRT adapter has been extracted to the closed cognee-llm-litert crate."
625                        .to_string(),
626                ));
627            }
628            _ => {
629                return Err(ComponentError::Config(format!(
630                    "Unsupported llm_provider '{provider}'. Supported: openai, mock.",
631                )));
632            }
633        };
634
635        // Optional recording wrap (`COGNEE_RECORD_LLM`). Only the real adapter is
636        // worth recording — replaying a recording of a mock is pointless.
637        if !llm_record_path.trim().is_empty() {
638            #[cfg(feature = "mock-llm")]
639            {
640                let recorder = cognee_llm::mock::RecordingLlm::new(
641                    adapter,
642                    llm_record_path.trim().to_string(),
643                );
644                return Ok(Arc::new(recorder));
645            }
646            #[cfg(not(feature = "mock-llm"))]
647            {
648                return Err(ComponentError::Config(
649                    "COGNEE_RECORD_LLM was set but LLM recording is unavailable; \
650                     rebuild with the `mock-llm` feature"
651                        .to_string(),
652                ));
653            }
654        }
655
656        Ok(adapter)
657    }
658
659    async fn init_transcriber(&self) -> Result<Option<Arc<dyn Transcriber>>, ComponentError> {
660        let (provider, llm_model, llm_api_key, llm_endpoint, llm_max_retries) = {
661            let s = self.config.read();
662            (
663                s.llm_provider.to_lowercase(),
664                s.llm_model.clone(),
665                s.llm_api_key.clone(),
666                s.llm_endpoint.clone(),
667                s.llm_max_retries,
668            )
669        };
670
671        match provider.as_str() {
672            "openai" => {
673                if llm_api_key.is_empty() {
674                    return Err(ComponentError::Config(
675                        "llm_api_key must be configured".to_string(),
676                    ));
677                }
678
679                let endpoint = if llm_endpoint.is_empty() {
680                    None
681                } else {
682                    Some(llm_endpoint)
683                };
684
685                let retries = llm_max_retries.max(1);
686
687                let adapter = OpenAIAdapter::new(llm_model, llm_api_key, endpoint)
688                    .map_err(|e| ComponentError::Llm(format!("initialization failed: {e}")))?
689                    .with_structured_output_retries(retries)
690                    .with_network_retries(retries);
691
692                Ok(Some(Arc::new(adapter) as Arc<dyn Transcriber>))
693            }
694            // litert and any future providers that do not implement Transcriber
695            // return None — audio stays gracefully unsupported (D5).
696            _ => Ok(None),
697        }
698    }
699
700    /// Return the [`Transcriber`] for the configured LLM provider, if supported.
701    ///
702    /// Returns `Ok(Some(_))` for OpenAI (Whisper). Returns `Ok(None)` for
703    /// providers that do not support audio transcription (e.g. `litert`), so
704    /// callers can skip registering the `AudioLoader` rather than failing.
705    pub async fn transcriber(&self) -> Result<Option<Arc<dyn Transcriber>>, ComponentError> {
706        let current_ver = self.config.version();
707        // Fast path: read lock
708        {
709            let guard = self.transcriber.read().await;
710            if let Some((ver, ref opt)) = *guard
711                && ver == current_ver
712            {
713                return Ok(opt.clone());
714            }
715        }
716        // Slow path: write lock with double-check
717        let mut guard = self.transcriber.write().await;
718        if let Some((ver, ref opt)) = *guard
719            && ver == current_ver
720        {
721            return Ok(opt.clone());
722        }
723        let new = self.init_transcriber().await?;
724        *guard = Some((current_ver, new.clone()));
725        Ok(new)
726    }
727}
728
// Body shared by the version-checked component accessors, so the
// double-checked locking sequence is written only once.
//
// `$self.$field` must be a lock exposing async `read()` / `write()` over an
// `Option<(version, Arc<_>)>`, and `$self.$init_fn()` an async initializer
// returning `Result<Arc<_>, _>`. The expansion contains `return` statements,
// so it has to be the body of a function returning that same `Result`.
macro_rules! versioned_accessor {
    ($self:ident, $field:ident, $init_fn:ident) => {{
        let wanted = $self.config.version();

        // Shared lock first: hand out the cached component when it was built
        // for the current config version. The guard drops with this scope.
        {
            let cached = $self.$field.read().await;
            match &*cached {
                Some((ver, component)) if *ver == wanted => return Ok(Arc::clone(component)),
                _ => {}
            }
        }

        // Exclusive lock, then look again — another task may have
        // reinitialized the slot while this one was waiting.
        let mut slot = $self.$field.write().await;
        match &*slot {
            Some((ver, component)) if *ver == wanted => return Ok(Arc::clone(component)),
            _ => {}
        }

        // Still missing or stale: build a new component and remember which
        // version it belongs to. On error `?` returns early and the slot keeps
        // its previous contents.
        let fresh = $self.$init_fn().await?;
        *slot = Some((wanted, Arc::clone(&fresh)));
        Ok(fresh)
    }};
}
756
757#[async_trait]
758impl PipelineContext for ComponentManager {
759    async fn storage(&self) -> Result<Arc<dyn StorageTrait>, ComponentError> {
760        versioned_accessor!(self, storage, init_storage)
761    }
762
763    async fn database(&self) -> Result<Arc<DatabaseConnection>, ComponentError> {
764        versioned_accessor!(self, database, init_database)
765    }
766
767    async fn graph_db(&self) -> Result<Arc<dyn GraphDBTrait>, ComponentError> {
768        versioned_accessor!(self, graph_db, init_graph_db)
769    }
770
771    async fn vector_db(&self) -> Result<Arc<dyn VectorDB>, ComponentError> {
772        versioned_accessor!(self, vector_db, init_vector_db)
773    }
774
775    async fn embedding_engine(&self) -> Result<Arc<dyn EmbeddingEngine>, ComponentError> {
776        versioned_accessor!(self, embedding_engine, init_embedding_engine)
777    }
778
779    async fn llm(&self) -> Result<Arc<dyn Llm>, ComponentError> {
780        versioned_accessor!(self, llm, init_llm)
781    }
782}
783
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    reason = "test code — panics are acceptable failures"
)]
mod tests {
    use super::*;
    use crate::config::{ConfigManager, Settings};

    /// Wrap `settings` in a fresh `ConfigManager` and build a manager on it.
    fn manager_from(settings: Settings) -> ComponentManager {
        ComponentManager::new(ConfigManager::new(settings))
    }

    /// Manager whose LLM settings name `provider`, with a placeholder API key
    /// and model; every other setting keeps its default.
    fn cm_with_provider(provider: &str) -> ComponentManager {
        manager_from(Settings {
            llm_provider: provider.to_owned(),
            llm_api_key: "sk-test".to_owned(),
            llm_model: "gpt-4o-mini".to_owned(),
            ..Settings::default()
        })
    }

    #[tokio::test]
    async fn transcriber_returns_some_for_openai() {
        let cm = cm_with_provider("openai");
        let transcriber = cm
            .transcriber()
            .await
            .expect("transcriber() should not error");
        assert!(
            transcriber.is_some(),
            "openai provider must yield Some(transcriber)"
        );
    }

    #[tokio::test]
    async fn transcriber_returns_none_for_unknown_provider() {
        // A provider other than openai (here "mock") is expected to produce
        // `None` rather than an error, even with an empty API key.
        let cm = manager_from(Settings {
            llm_provider: "mock".to_owned(),
            llm_api_key: String::new(),
            ..Settings::default()
        });
        let transcriber = cm
            .transcriber()
            .await
            .expect("transcriber() should not error for mock");
        assert!(transcriber.is_none(), "non-openai provider must yield None");
    }

    #[tokio::test]
    async fn transcriber_is_cached_across_calls() {
        let cm = cm_with_provider("openai");
        let first = cm.transcriber().await.expect("first call").unwrap();
        let second = cm.transcriber().await.expect("second call").unwrap();
        // Same allocation behind both handles means the second call was
        // served from the cache.
        assert!(Arc::ptr_eq(&first, &second), "transcriber should be cached");
    }

    // -- resolved_graph_db_url / PgGraph provider dispatch --------------------

    #[cfg(feature = "pggraph")]
    fn cm_with_graph_settings(settings: Settings) -> ComponentManager {
        manager_from(settings)
    }

    #[cfg(feature = "pggraph")]
    #[test]
    fn resolved_graph_db_url_returns_explicit_url_as_is() {
        let explicit = "postgres://user:pw@myhost:5432/graphs";
        let settings = Settings {
            graph_database_url: explicit.to_owned(),
            ..Settings::default()
        };
        let cm = cm_with_graph_settings(settings.clone());
        let url = cm
            .resolved_graph_db_url(&settings)
            .expect("should succeed with full URL");
        assert_eq!(url, explicit);
    }

    #[cfg(feature = "pggraph")]
    #[test]
    fn resolved_graph_db_url_builds_from_graph_creds() {
        let settings = Settings {
            graph_database_host: "graphhost".to_owned(),
            graph_database_port: 5432,
            graph_database_name: "mygraph".to_owned(),
            graph_database_username: "guser".to_owned(),
            graph_database_password: "gpass".to_owned(),
            ..Settings::default()
        };
        let cm = cm_with_graph_settings(settings.clone());
        let url = cm
            .resolved_graph_db_url(&settings)
            .expect("should build from graph creds");
        for (needle, why) in [
            ("guser", "URL should contain username"),
            ("graphhost", "URL should contain host"),
            ("mygraph", "URL should contain db name"),
        ] {
            assert!(url.contains(needle), "{why}");
        }
    }

    #[cfg(feature = "pggraph")]
    #[test]
    fn resolved_graph_db_url_falls_back_to_relational_creds() {
        // Only the relational credentials are filled in; the graph-specific
        // ones keep their defaults, so the resolver has to fall back.
        let settings = Settings {
            db_host: "relhost".to_owned(),
            db_port: 5432,
            db_name: "reldb".to_owned(),
            db_username: "reluser".to_owned(),
            db_password: "relpass".to_owned(),
            ..Settings::default()
        };
        let cm = cm_with_graph_settings(settings.clone());
        let url = cm
            .resolved_graph_db_url(&settings)
            .expect("should fall back to relational creds");
        for (needle, why) in [
            ("reluser", "URL should contain relational username"),
            ("relhost", "URL should contain relational host"),
            ("reldb", "URL should contain relational db name"),
        ] {
            assert!(url.contains(needle), "{why}");
        }
    }

    #[cfg(feature = "pggraph")]
    #[test]
    fn resolved_graph_db_url_errors_when_no_creds() {
        // Relational credentials are blanked and graph credentials stay at
        // their defaults, so resolution must report a config error.
        let settings = Settings {
            db_host: String::new(),
            db_name: String::new(),
            db_username: String::new(),
            ..Settings::default()
        };
        let cm = cm_with_graph_settings(settings.clone());
        assert!(
            cm.resolved_graph_db_url(&settings).is_err(),
            "should error when no creds available"
        );
    }

    #[tokio::test]
    async fn init_graph_db_rejects_unsupported_provider() {
        let cm = manager_from(Settings {
            graph_database_provider: "neo4j".to_owned(),
            ..Settings::default()
        });
        let result = cm.graph_db().await;
        assert!(result.is_err());
        // Destructure by pattern instead of `unwrap_err`, which would need a
        // `Debug` impl on the `Ok` payload.
        let Err(err) = result else {
            panic!("expected error")
        };
        let err_msg = err.to_string();
        assert!(
            err_msg.contains("postgres"),
            "error message should list 'postgres' as supported: {err_msg}"
        );
    }

    // -- Mock LLM factory wiring (MOCK_LLM / COGNEE_RECORD_LLM) ----------------

    /// Create a temp dir holding a minimal valid cassette; returns (dir, path).
    ///
    /// Callers keep the returned `TempDir` bound for as long as they need the
    /// cassette file.
    #[cfg(feature = "mock-llm")]
    fn write_cassette() -> (tempfile::TempDir, std::path::PathBuf) {
        // Schema-valid cassette without entries: the replay mock then uses its
        // default EmptyGraph miss policy and the pipeline runs with no entries.
        const EMPTY_CASSETTE: &str = r#"{"version":1,"model":"mock-model","entries":{}}"#;

        let dir = tempfile::tempdir().expect("tempdir");
        let path = dir.path().join("cassette.json");
        std::fs::write(&path, EMPTY_CASSETTE).expect("write cassette");
        (dir, path)
    }

    #[cfg(feature = "mock-llm")]
    #[tokio::test]
    async fn init_llm_uses_replay_mock_when_llm_mock_set_without_api_key() {
        let (_dir, cassette) = write_cassette();
        // Empty API key, so a real provider could not be set up — the mock
        // flag has to take precedence.
        let cm = manager_from(Settings {
            llm_mock: true,
            llm_cassette: cassette.to_string_lossy().into_owned(),
            llm_api_key: String::new(),
            ..Settings::default()
        });
        let llm = cm.llm().await.expect("mock llm should initialize offline");
        assert_eq!(
            llm.model(),
            "mock-model",
            "replay mock reports cassette model"
        );

        // generate() has to work offline too (empty response on cache miss).
        let prompt = vec![cognee_llm::Message {
            role: cognee_llm::MessageRole::User,
            content: "hello".to_owned(),
        }];
        let resp = llm
            .generate(prompt, None)
            .await
            .expect("offline generate should succeed");
        assert_eq!(resp.model, "mock-model");
    }

    #[cfg(feature = "mock-llm")]
    #[tokio::test]
    async fn init_llm_selects_mock_when_provider_is_mock() {
        let (_dir, cassette) = write_cassette();
        let cm = manager_from(Settings {
            llm_provider: "mock".to_owned(),
            llm_cassette: cassette.to_string_lossy().into_owned(),
            llm_api_key: String::new(),
            ..Settings::default()
        });
        let llm = cm
            .llm()
            .await
            .expect("provider=mock should initialize offline");
        assert_eq!(llm.model(), "mock-model");
    }

    #[cfg(feature = "mock-llm")]
    #[tokio::test]
    async fn init_llm_errors_when_mock_set_but_cassette_empty() {
        let cm = manager_from(Settings {
            llm_mock: true,
            llm_cassette: String::new(),
            ..Settings::default()
        });
        let Err(err) = cm.llm().await else {
            panic!("empty cassette must error")
        };
        assert!(
            err.to_string().contains("MOCK_LLM_CASSETTE"),
            "error should mention the missing cassette env: {err}"
        );
    }

    #[cfg(feature = "mock-llm")]
    #[tokio::test]
    async fn init_llm_wraps_real_adapter_in_recorder_when_record_path_set() {
        let dir = tempfile::tempdir().expect("tempdir");
        let record_path = dir.path().join("recorded.json");
        // A real openai provider combined with a record path is expected to be
        // wrapped in RecordingLlm.
        let cm = manager_from(Settings {
            llm_provider: "openai".to_owned(),
            llm_api_key: "sk-test".to_owned(),
            llm_model: "gpt-4o-mini".to_owned(),
            llm_record_path: record_path.to_string_lossy().into_owned(),
            ..Settings::default()
        });
        // Initialization must not need the network; model() on the recorder
        // is answered by the wrapped openai adapter.
        let llm = cm
            .llm()
            .await
            .expect("recording wrap should initialize without network");
        assert_eq!(
            llm.model(),
            "gpt-4o-mini",
            "recorder delegates model() to the wrapped adapter"
        );
    }
}