// recoco_core/lib_context.rs
1// ReCoco is a Rust-only fork of CocoIndex, by [CocoIndex](https://CocoIndex)
2// Original code from CocoIndex is copyrighted by CocoIndex
3// SPDX-FileCopyrightText: 2025-2026 CocoIndex (upstream)
4// SPDX-FileContributor: CocoIndex Contributors
5//
6// All modifications from the upstream for ReCoco are copyrighted by Knitli Inc.
7// SPDX-FileCopyrightText: 2026 Knitli Inc. (ReCoco)
8// SPDX-FileContributor: Adam Poulemanos <adam@knit.li>
9//
10// Both the upstream CocoIndex code and the ReCoco modifications are licensed under the Apache-2.0 License.
11// SPDX-License-Identifier: Apache-2.0
12
13#[cfg(feature = "persistence")]
14use std::time::Duration;
15
16use crate::prelude::*;
17
18use crate::builder::AnalyzedFlow;
19#[cfg(feature = "persistence")]
20use crate::execution::source_indexer::SourceIndexingContext;
21#[cfg(feature = "persistence")]
22use crate::service::query_handler::{QueryHandler, QueryHandlerSpec};
23use crate::settings;
24#[cfg(feature = "persistence")]
25use crate::setup::ObjectSetupChange;
26#[cfg(feature = "server")]
27use axum::http::StatusCode;
28#[cfg(feature = "server")]
29use recoco_utils::error::ApiError;
30#[cfg(feature = "persistence")]
31use sqlx::PgPool;
32#[cfg(feature = "persistence")]
33use sqlx::postgres::{PgConnectOptions, PgPoolOptions};
34use tokio::runtime::Runtime;
35use tracing_subscriber::{EnvFilter, fmt, prelude::*};
36
#[cfg(feature = "persistence")]
pub struct FlowExecutionContext {
    /// Execution-time view of the flow's setup state, shared with indexing tasks.
    pub setup_execution_context: Arc<exec_ctx::FlowSetupExecutionContext>,
    /// Diff between the desired setup state and the recorded existing state.
    pub setup_change: setup::FlowSetupChange,
    // One lazily-initialized indexing context per import op; the index matches
    // the position in `flow_instance.import_ops` (sized in `Self::new`).
    source_indexing_contexts: Vec<tokio::sync::OnceCell<Arc<SourceIndexingContext>>>,
}
43
/// Build the setup execution context for `analyzed_flow` and compute the setup
/// change needed to reconcile it with `existing_flow_ss` (the recorded state).
#[cfg(feature = "persistence")]
async fn build_setup_context(
    analyzed_flow: &AnalyzedFlow,
    existing_flow_ss: Option<&setup::FlowSetupState<setup::ExistingMode>>,
) -> Result<(
    Arc<exec_ctx::FlowSetupExecutionContext>,
    setup::FlowSetupChange,
)> {
    // Derive the execution context from the analyzed flow and existing state.
    let exec_context = exec_ctx::build_flow_setup_execution_context(
        &analyzed_flow.flow_instance,
        &analyzed_flow.data_schema,
        &analyzed_flow.setup_state,
        existing_flow_ss,
    )
    .map(Arc::new)?;

    // Diff desired vs. existing to find what setup work remains.
    let change = setup::diff_flow_setup_states(
        Some(&exec_context.setup_state),
        existing_flow_ss,
        &analyzed_flow.flow_instance_ctx,
    )
    .await?;

    Ok((exec_context, change))
}
68
#[cfg(feature = "persistence")]
impl FlowExecutionContext {
    /// Create an execution context for `analyzed_flow`: compute the setup diff
    /// against `existing_flow_ss` and allocate one empty indexing-context slot
    /// per import op.
    async fn new(
        analyzed_flow: &AnalyzedFlow,
        existing_flow_ss: Option<&setup::FlowSetupState<setup::ExistingMode>>,
    ) -> Result<Self> {
        let (setup_execution_context, setup_change) =
            build_setup_context(analyzed_flow, existing_flow_ss).await?;

        // One lazily-filled cell per import op, indexed by source position.
        let source_indexing_contexts = std::iter::repeat_with(tokio::sync::OnceCell::new)
            .take(analyzed_flow.flow_instance.import_ops.len())
            .collect();

        Ok(Self {
            setup_execution_context,
            setup_change,
            source_indexing_contexts,
        })
    }

    /// Recompute the setup execution context and setup diff (e.g. after the
    /// recorded existing state changed). Intentionally leaves already-loaded
    /// `source_indexing_contexts` untouched.
    pub async fn update_setup_state(
        &mut self,
        analyzed_flow: &AnalyzedFlow,
        existing_flow_ss: Option<&setup::FlowSetupState<setup::ExistingMode>>,
    ) -> Result<()> {
        let (setup_execution_context, setup_change) =
            build_setup_context(analyzed_flow, existing_flow_ss).await?;

        self.setup_execution_context = setup_execution_context;
        self.setup_change = setup_change;
        Ok(())
    }

    /// Get the indexing context for the import op at `source_idx`, loading it
    /// on first use (concurrent callers share a single load via `OnceCell`).
    ///
    /// # Errors
    /// Returns an error if `source_idx` is out of range (previously this
    /// panicked via direct slice indexing), or if loading the context fails.
    pub async fn get_source_indexing_context(
        &self,
        flow: &Arc<AnalyzedFlow>,
        source_idx: usize,
        pool: &PgPool,
    ) -> Result<&Arc<SourceIndexingContext>> {
        let cell = self.source_indexing_contexts.get(source_idx).ok_or_else(|| {
            anyhow::anyhow!(
                "source index {source_idx} out of range (flow has {} import ops)",
                self.source_indexing_contexts.len()
            )
        })?;
        cell.get_or_try_init(|| async move {
            SourceIndexingContext::load(
                flow.clone(),
                source_idx,
                self.setup_execution_context.clone(),
                pool,
            )
            .await
        })
        .await
    }
}
122
#[cfg(feature = "persistence")]
pub struct QueryHandlerContext {
    /// Spec/metadata describing the query handler.
    pub info: Arc<QueryHandlerSpec>,
    /// The handler implementation itself (shared trait object).
    pub handler: Arc<dyn QueryHandler>,
}
128
/// Per-flow runtime context: the analyzed flow plus (with persistence enabled)
/// its execution state and registered query handlers.
pub struct FlowContext {
    pub flow: Arc<AnalyzedFlow>,
    // Async RwLock: readers use the context for execution while setup updates
    // take the write lock (see `get_execution_ctx_for_setup`).
    #[cfg(feature = "persistence")]
    execution_ctx: Arc<tokio::sync::RwLock<FlowExecutionContext>>,
    // Registered query handlers, keyed by name (presumably the handler name —
    // confirm at the registration site).
    #[cfg(feature = "persistence")]
    pub query_handlers: RwLock<HashMap<String, QueryHandlerContext>>,
}
136
137impl FlowContext {
138    pub fn flow_name(&self) -> &str {
139        &self.flow.flow_instance.name
140    }
141
142    #[cfg(feature = "persistence")]
143    pub async fn new(
144        flow: Arc<AnalyzedFlow>,
145        existing_flow_ss: Option<&setup::FlowSetupState<setup::ExistingMode>>,
146    ) -> Result<Self> {
147        let execution_ctx = Arc::new(tokio::sync::RwLock::new(
148            FlowExecutionContext::new(&flow, existing_flow_ss).await?,
149        ));
150        Ok(Self {
151            flow,
152            execution_ctx,
153            query_handlers: RwLock::new(HashMap::new()),
154        })
155    }
156
157    #[cfg(not(feature = "persistence"))]
158    pub fn new_transient(flow: Arc<AnalyzedFlow>) -> Self {
159        Self { flow }
160    }
161
162    #[cfg(feature = "persistence")]
163    pub async fn use_execution_ctx(
164        &self,
165    ) -> Result<tokio::sync::RwLockReadGuard<'_, FlowExecutionContext>> {
166        let execution_ctx = self.execution_ctx.read().await;
167        if !execution_ctx.setup_change.is_up_to_date() {
168            api_bail!(
169                "Setup for flow `{}` is not up-to-date. Please run `cocoindex setup` to update the setup.",
170                self.flow_name()
171            );
172        }
173        Ok(execution_ctx)
174    }
175
176    #[cfg(feature = "persistence")]
177    pub async fn use_owned_execution_ctx(
178        &self,
179    ) -> Result<tokio::sync::OwnedRwLockReadGuard<FlowExecutionContext>> {
180        let execution_ctx = self.execution_ctx.clone().read_owned().await;
181        if !execution_ctx.setup_change.is_up_to_date() {
182            api_bail!(
183                "Setup for flow `{}` is not up-to-date. Please run `cocoindex setup` to update the setup.",
184                self.flow_name()
185            );
186        }
187        Ok(execution_ctx)
188    }
189
190    #[cfg(feature = "persistence")]
191    pub fn get_execution_ctx_for_setup(&self) -> &tokio::sync::RwLock<FlowExecutionContext> {
192        &self.execution_ctx
193    }
194}
195
196static TOKIO_RUNTIME: LazyLock<Runtime> = LazyLock::new(|| Runtime::new().unwrap());
197static AUTH_REGISTRY: LazyLock<Arc<AuthRegistry>> = LazyLock::new(|| Arc::new(AuthRegistry::new()));
198
199pub fn get_runtime() -> &'static Runtime {
200    &TOKIO_RUNTIME
201}
202pub fn get_auth_registry() -> &'static Arc<AuthRegistry> {
203    &AUTH_REGISTRY
204}
205
// Pools are keyed by (database URL, optional user) so distinct credentials to
// the same URL get distinct pools.
#[cfg(feature = "persistence")]
type PoolKey = (String, Option<String>);
#[cfg(feature = "persistence")]
type PoolValue = Arc<tokio::sync::OnceCell<PgPool>>;

/// Cache of Postgres connection pools, one per (url, user) pair.
#[derive(Default)]
pub struct DbPools {
    // OnceCell per key so concurrent callers share a single connection attempt.
    #[cfg(feature = "persistence")]
    pub pools: Mutex<HashMap<PoolKey, PoolValue>>,
}
216
impl DbPools {
    /// Get (or lazily create) the connection pool for `conn_spec`.
    ///
    /// The per-key `OnceCell` guarantees a single initialization even when
    /// multiple tasks request the same pool concurrently; the outer `Mutex` is
    /// held only to fetch/insert the cell, never across an `.await`.
    #[cfg(feature = "persistence")]
    pub async fn get_pool(&self, conn_spec: &settings::DatabaseConnectionSpec) -> Result<PgPool> {
        let db_pool_cell = {
            let key = (conn_spec.url.clone(), conn_spec.user.clone());
            let mut db_pools = self.pools.lock().unwrap();
            db_pools.entry(key).or_default().clone()
        };
        let pool = db_pool_cell
            .get_or_try_init(|| async move {
                // Credentials in the spec override whatever the URL carries.
                let mut pg_options: PgConnectOptions = conn_spec.url.parse()?;
                if let Some(user) = &conn_spec.user {
                    pg_options = pg_options.username(user);
                }
                if let Some(password) = &conn_spec.password {
                    pg_options = pg_options.password(password);
                }

                // Try to connect to the database with a low timeout first, so an
                // unreachable database fails within 30s instead of the main
                // pool's 5-minute acquire timeout.
                {
                    let pool_options = PgPoolOptions::new()
                        .max_connections(1)
                        .min_connections(1)
                        .acquire_timeout(Duration::from_secs(30));
                    let pool = pool_options
                        .connect_with(pg_options.clone())
                        .await
                        .map_err(Error::from)
                        .with_context(|| {
                            format!("Failed to connect to database {}", conn_spec.url)
                        })?;
                    // Force an actual connection acquisition; the guard (and the
                    // probe pool) are dropped immediately after.
                    let _ = pool.acquire().await?;
                }

                // Now create the actual pool with the configured limits.
                let pool_options = PgPoolOptions::new()
                    .max_connections(conn_spec.max_connections)
                    .min_connections(conn_spec.min_connections)
                    // Log at INFO when acquiring a connection takes > 10s.
                    .acquire_slow_level(log::LevelFilter::Info)
                    .acquire_slow_threshold(Duration::from_secs(10))
                    .acquire_timeout(Duration::from_secs(5 * 60));
                let pool = pool_options
                    .connect_with(pg_options)
                    .await
                    .map_err(Error::from)
                    .with_context(|| "Failed to connect to database")?;
                Ok::<_, Error>(pool)
            })
            .await?;
        // PgPool is cheap to clone (internally reference-counted by sqlx).
        Ok(pool.clone())
    }
}
269
/// Library-wide setup bookkeeping: all recorded setup states plus the pending
/// global change derived from them.
#[cfg(feature = "persistence")]
pub struct LibSetupContext {
    pub all_setup_states: setup::AllSetupStates<setup::ExistingMode>,
    pub global_setup_change: setup::GlobalSetupChange,
}
/// State that only exists when a database is configured in settings.
#[cfg(feature = "persistence")]
pub struct PersistenceContext {
    /// Pool for the built-in metadata database.
    pub builtin_db_pool: PgPool,
    pub setup_ctx: tokio::sync::RwLock<LibSetupContext>,
}
280
/// Top-level library context created by `create_lib_context`: shared DB pools,
/// registered flows, and settings-derived global state.
pub struct LibContext {
    pub db_pools: DbPools,
    // `None` when no database is configured in settings.
    #[cfg(feature = "persistence")]
    pub persistence_ctx: Option<PersistenceContext>,
    // Registered flows keyed by flow name (see `get_flow_context`); BTreeMap
    // keeps enumeration in name order.
    pub flows: Mutex<BTreeMap<String, Arc<FlowContext>>>,
    pub app_namespace: String,
    // When true, failures while dropping target backends are logged and ignored.
    pub ignore_target_drop_failures: bool,
    pub global_concurrency_controller: Arc<concur_control::ConcurrencyController>,
}
291
impl LibContext {
    /// Look up a registered flow context by name.
    ///
    /// With the `server` feature the not-found error is an `ApiError` carrying
    /// HTTP 404; otherwise it is a plain `anyhow` error.
    pub fn get_flow_context(&self, flow_name: &str) -> Result<Arc<FlowContext>> {
        let flows = self.flows.lock().unwrap();
        let flow_ctx = flows
            .get(flow_name)
            .ok_or_else(|| {
                #[cfg(feature = "server")]
                {
                    ApiError::new(
                        &format!("Flow instance not found: {flow_name}"),
                        StatusCode::NOT_FOUND,
                    )
                }
                #[cfg(not(feature = "server"))]
                {
                    anyhow::anyhow!("Flow instance not found: {flow_name}")
                }
            })?
            .clone();
        Ok(flow_ctx)
    }

    /// Unregister a flow context; no-op if the name is absent.
    pub fn remove_flow_context(&self, flow_name: &str) {
        let mut flows = self.flows.lock().unwrap();
        flows.remove(flow_name);
    }

    /// Return the persistence context, or a client error instructing the user
    /// to configure a database.
    // NOTE(review): the docs URL in the message ("https://CocoIndex/docs/...")
    // looks garbled by the fork rename — confirm the intended link.
    #[cfg(feature = "persistence")]
    pub fn require_persistence_ctx(&self) -> Result<&PersistenceContext> {
        self.persistence_ctx.as_ref().ok_or_else(|| {
            client_error!(
                "Database is required for this operation. \
                         The easiest way is to set COCOINDEX_DATABASE_URL environment variable. \
                         Please see https://CocoIndex/docs/core/settings for more details."
            )
        })
    }

    /// Shortcut for the built-in metadata DB pool; errors when no database is
    /// configured.
    #[cfg(feature = "persistence")]
    pub fn require_builtin_db_pool(&self) -> Result<&PgPool> {
        Ok(&self.require_persistence_ctx()?.builtin_db_pool)
    }
}
335
// Guards one-time, process-level initialization (tracing + rustls provider).
static LIB_INIT: OnceLock<()> = OnceLock::new();

/// Build a `LibContext` from `settings`: run one-time global init, then (with
/// persistence enabled and a database configured) connect to the database and
/// load the recorded setup states.
pub async fn create_lib_context(settings: settings::Settings) -> Result<LibContext> {
    LIB_INIT.get_or_init(|| {
        // Initialize tracing subscriber with env filter for log level control.
        // Default to "info" level if RUST_LOG is not set.
        let env_filter =
            EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info"));
        // try_init: ignore the error if a subscriber was already installed.
        let _ = tracing_subscriber::registry()
            .with(fmt::layer())
            .with(env_filter)
            .try_init();
        // TLS-using features need a process-default crypto provider installed.
        #[cfg(any(feature = "server", feature = "source-gdrive", feature = "source-s3"))]
        let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();
    });

    let db_pools = DbPools::default();
    #[cfg(feature = "persistence")]
    let persistence_ctx = if let Some(database_spec) = &settings.database {
        let pool = db_pools.get_pool(database_spec).await?;
        // Load what already exists so later operations can diff against it.
        let all_setup_states = setup::get_existing_setup_state(&pool).await?;
        Some(PersistenceContext {
            builtin_db_pool: pool,
            setup_ctx: tokio::sync::RwLock::new(LibSetupContext {
                global_setup_change: setup::GlobalSetupChange::from_setup_states(&all_setup_states),
                all_setup_states,
            }),
        })
    } else {
        // No database configured
        None
    };

    Ok(LibContext {
        db_pools,
        #[cfg(feature = "persistence")]
        persistence_ctx,
        flows: Mutex::new(BTreeMap::new()),
        app_namespace: settings.app_namespace,
        ignore_target_drop_failures: settings.ignore_target_drop_failures,
        global_concurrency_controller: Arc::new(concur_control::ConcurrencyController::new(
            &concur_control::Options {
                max_inflight_rows: settings.global_execution_options.source_max_inflight_rows,
                max_inflight_bytes: settings.global_execution_options.source_max_inflight_bytes,
            },
        )),
    })
}
383
384#[allow(clippy::type_complexity)]
385static GET_SETTINGS_FN: Mutex<Option<Box<dyn Fn() -> Result<settings::Settings> + Send + Sync>>> =
386    Mutex::new(None);
387fn get_settings() -> Result<settings::Settings> {
388    let get_settings_fn = GET_SETTINGS_FN.lock().unwrap();
389    let settings = if let Some(get_settings_fn) = &*get_settings_fn {
390        get_settings_fn()?
391    } else {
392        client_bail!("CocoIndex setting function is not provided");
393    };
394    Ok(settings)
395}
396
397pub fn set_settings_fn(get_settings_fn: Box<dyn Fn() -> Result<settings::Settings> + Send + Sync>) {
398    let mut get_settings_fn_locked = GET_SETTINGS_FN.lock().unwrap();
399    *get_settings_fn_locked = Some(get_settings_fn);
400}
401
// Process-global singleton `LibContext`. An async Mutex because creation
// (`create_lib_context`) is awaited while the lock is held.
static LIB_CONTEXT: LazyLock<tokio::sync::Mutex<Option<Arc<LibContext>>>> =
    LazyLock::new(|| tokio::sync::Mutex::new(None));
404
405pub async fn init_lib_context(settings: Option<settings::Settings>) -> Result<()> {
406    let settings = match settings {
407        Some(settings) => settings,
408        None => get_settings()?,
409    };
410    let mut lib_context_locked = LIB_CONTEXT.lock().await;
411    *lib_context_locked = Some(Arc::new(create_lib_context(settings).await?));
412    Ok(())
413}
414
415pub async fn get_lib_context() -> Result<Arc<LibContext>> {
416    let mut lib_context_locked = LIB_CONTEXT.lock().await;
417    let lib_context = if let Some(lib_context) = &*lib_context_locked {
418        lib_context.clone()
419    } else {
420        let setting = get_settings()?;
421        let lib_context = Arc::new(create_lib_context(setting).await?);
422        *lib_context_locked = Some(lib_context.clone());
423        lib_context
424    };
425    Ok(lib_context)
426}
427
428pub async fn clear_lib_context() {
429    let mut lib_context_locked = LIB_CONTEXT.lock().await;
430    *lib_context_locked = None;
431}
432
#[cfg(test)]
mod tests {
    use super::*;

    // A default DbPools must start with no cached pools.
    #[test]
    fn test_db_pools_default() {
        let _db_pools = DbPools::default();
        #[cfg(feature = "persistence")]
        assert!(_db_pools.pools.lock().unwrap().is_empty());
    }

    // Without a configured database, persistence_ctx is None and the builtin
    // pool accessor must return an error rather than panic.
    #[cfg(feature = "persistence")]
    #[tokio::test]
    async fn test_lib_context_without_database() {
        let lib_context = create_lib_context(settings::Settings::default())
            .await
            .unwrap();
        assert!(lib_context.persistence_ctx.is_none());
        assert!(lib_context.require_builtin_db_pool().is_err());
    }

    #[cfg(feature = "persistence")]
    #[tokio::test]
    async fn test_persistence_context_type_safety() {
        // This test ensures that PersistenceContext groups related fields together
        let settings = settings::Settings {
            database: Some(settings::DatabaseConnectionSpec {
                url: "postgresql://test".to_string(),
                user: None,
                password: None,
                max_connections: 10,
                min_connections: 1,
            }),
            ..Default::default()
        };

        // This would fail at runtime due to invalid connection, but we're testing the structure
        let result = create_lib_context(settings).await;
        // We expect this to fail due to invalid connection, but the structure should be correct
        assert!(result.is_err());
    }
}