1#[cfg(feature = "persistence")]
14use std::time::Duration;
15
16use crate::prelude::*;
17
18use crate::builder::AnalyzedFlow;
19#[cfg(feature = "persistence")]
20use crate::execution::source_indexer::SourceIndexingContext;
21#[cfg(feature = "persistence")]
22use crate::service::query_handler::{QueryHandler, QueryHandlerSpec};
23use crate::settings;
24#[cfg(feature = "persistence")]
25use crate::setup::ObjectSetupChange;
26#[cfg(feature = "server")]
27use axum::http::StatusCode;
28#[cfg(feature = "server")]
29use recoco_utils::error::ApiError;
30#[cfg(feature = "persistence")]
31use sqlx::PgPool;
32#[cfg(feature = "persistence")]
33use sqlx::postgres::{PgConnectOptions, PgPoolOptions};
34use tokio::runtime::Runtime;
35use tracing_subscriber::{EnvFilter, fmt, prelude::*};
36
/// Runtime execution state for a single flow when persistence is enabled.
#[cfg(feature = "persistence")]
pub struct FlowExecutionContext {
    /// Execution-time view of the flow's setup state, shared with per-source
    /// indexing contexts.
    pub setup_execution_context: Arc<exec_ctx::FlowSetupExecutionContext>,
    /// Pending changes between the desired setup state and the existing one.
    pub setup_change: setup::FlowSetupChange,
    // One lazily-initialized indexing context per import op, in the same order
    // as `flow_instance.import_ops`.
    source_indexing_contexts: Vec<tokio::sync::OnceCell<Arc<SourceIndexingContext>>>,
}
43
/// Builds the setup execution context for `analyzed_flow` and diffs it against
/// the flow's existing persisted setup state (if any).
///
/// Returns the new execution context together with the change set that would
/// bring the existing state up to date.
#[cfg(feature = "persistence")]
async fn build_setup_context(
    analyzed_flow: &AnalyzedFlow,
    existing_flow_ss: Option<&setup::FlowSetupState<setup::ExistingMode>>,
) -> Result<(
    Arc<exec_ctx::FlowSetupExecutionContext>,
    setup::FlowSetupChange,
)> {
    let setup_execution_context = Arc::new(exec_ctx::build_flow_setup_execution_context(
        &analyzed_flow.flow_instance,
        &analyzed_flow.data_schema,
        &analyzed_flow.setup_state,
        existing_flow_ss,
    )?);

    // Compute the delta from the existing persisted state to the desired state.
    let setup_change = setup::diff_flow_setup_states(
        Some(&setup_execution_context.setup_state),
        existing_flow_ss,
        &analyzed_flow.flow_instance_ctx,
    )
    .await?;

    Ok((setup_execution_context, setup_change))
}
68
#[cfg(feature = "persistence")]
impl FlowExecutionContext {
    /// Builds an execution context for `analyzed_flow`, diffing its desired
    /// setup state against `existing_flow_ss`.
    async fn new(
        analyzed_flow: &AnalyzedFlow,
        existing_flow_ss: Option<&setup::FlowSetupState<setup::ExistingMode>>,
    ) -> Result<Self> {
        let (setup_execution_context, setup_change) =
            build_setup_context(analyzed_flow, existing_flow_ss).await?;

        // One empty cell per import op; each source's indexing context is
        // created lazily on first use in `get_source_indexing_context`.
        let mut source_indexing_contexts = Vec::new();
        source_indexing_contexts.resize_with(analyzed_flow.flow_instance.import_ops.len(), || {
            tokio::sync::OnceCell::new()
        });

        Ok(Self {
            setup_execution_context,
            setup_change,
            source_indexing_contexts,
        })
    }

    /// Recomputes the setup execution context and change set (e.g. after the
    /// flow definition or persisted setup state changed). Already-initialized
    /// source indexing contexts are left untouched.
    pub async fn update_setup_state(
        &mut self,
        analyzed_flow: &AnalyzedFlow,
        existing_flow_ss: Option<&setup::FlowSetupState<setup::ExistingMode>>,
    ) -> Result<()> {
        let (setup_execution_context, setup_change) =
            build_setup_context(analyzed_flow, existing_flow_ss).await?;

        self.setup_execution_context = setup_execution_context;
        self.setup_change = setup_change;
        Ok(())
    }

    /// Returns the indexing context for the source at `source_idx`, loading it
    /// on first access. Concurrent callers for the same source share a single
    /// load through the `OnceCell`.
    ///
    /// NOTE(review): indexes into `source_indexing_contexts`, so this panics if
    /// `source_idx >= import_ops.len()` — callers must pass a valid index.
    pub async fn get_source_indexing_context(
        &self,
        flow: &Arc<AnalyzedFlow>,
        source_idx: usize,
        pool: &PgPool,
    ) -> Result<&Arc<SourceIndexingContext>> {
        self.source_indexing_contexts[source_idx]
            .get_or_try_init(|| async move {
                SourceIndexingContext::load(
                    flow.clone(),
                    source_idx,
                    self.setup_execution_context.clone(),
                    pool,
                )
                .await
            })
            .await
    }
}
122
/// A registered query handler together with its specification.
#[cfg(feature = "persistence")]
pub struct QueryHandlerContext {
    /// Specification describing this handler.
    pub info: Arc<QueryHandlerSpec>,
    /// The handler implementation.
    pub handler: Arc<dyn QueryHandler>,
}
128
/// Per-flow runtime state: the analyzed flow plus (with persistence enabled)
/// its execution context and registered query handlers.
pub struct FlowContext {
    pub flow: Arc<AnalyzedFlow>,
    // Behind an async RwLock so setup updates can swap the context while
    // readers hold guards via `use_execution_ctx` / `use_owned_execution_ctx`.
    #[cfg(feature = "persistence")]
    execution_ctx: Arc<tokio::sync::RwLock<FlowExecutionContext>>,
    /// Query handlers registered for this flow, keyed by handler name.
    #[cfg(feature = "persistence")]
    pub query_handlers: RwLock<HashMap<String, QueryHandlerContext>>,
}
136
137impl FlowContext {
138 pub fn flow_name(&self) -> &str {
139 &self.flow.flow_instance.name
140 }
141
142 #[cfg(feature = "persistence")]
143 pub async fn new(
144 flow: Arc<AnalyzedFlow>,
145 existing_flow_ss: Option<&setup::FlowSetupState<setup::ExistingMode>>,
146 ) -> Result<Self> {
147 let execution_ctx = Arc::new(tokio::sync::RwLock::new(
148 FlowExecutionContext::new(&flow, existing_flow_ss).await?,
149 ));
150 Ok(Self {
151 flow,
152 execution_ctx,
153 query_handlers: RwLock::new(HashMap::new()),
154 })
155 }
156
157 #[cfg(not(feature = "persistence"))]
158 pub fn new_transient(flow: Arc<AnalyzedFlow>) -> Self {
159 Self { flow }
160 }
161
162 #[cfg(feature = "persistence")]
163 pub async fn use_execution_ctx(
164 &self,
165 ) -> Result<tokio::sync::RwLockReadGuard<'_, FlowExecutionContext>> {
166 let execution_ctx = self.execution_ctx.read().await;
167 if !execution_ctx.setup_change.is_up_to_date() {
168 api_bail!(
169 "Setup for flow `{}` is not up-to-date. Please run `cocoindex setup` to update the setup.",
170 self.flow_name()
171 );
172 }
173 Ok(execution_ctx)
174 }
175
176 #[cfg(feature = "persistence")]
177 pub async fn use_owned_execution_ctx(
178 &self,
179 ) -> Result<tokio::sync::OwnedRwLockReadGuard<FlowExecutionContext>> {
180 let execution_ctx = self.execution_ctx.clone().read_owned().await;
181 if !execution_ctx.setup_change.is_up_to_date() {
182 api_bail!(
183 "Setup for flow `{}` is not up-to-date. Please run `cocoindex setup` to update the setup.",
184 self.flow_name()
185 );
186 }
187 Ok(execution_ctx)
188 }
189
190 #[cfg(feature = "persistence")]
191 pub fn get_execution_ctx_for_setup(&self) -> &tokio::sync::RwLock<FlowExecutionContext> {
192 &self.execution_ctx
193 }
194}
195
196static TOKIO_RUNTIME: LazyLock<Runtime> = LazyLock::new(|| Runtime::new().unwrap());
197static AUTH_REGISTRY: LazyLock<Arc<AuthRegistry>> = LazyLock::new(|| Arc::new(AuthRegistry::new()));
198
199pub fn get_runtime() -> &'static Runtime {
200 &TOKIO_RUNTIME
201}
202pub fn get_auth_registry() -> &'static Arc<AuthRegistry> {
203 &AUTH_REGISTRY
204}
205
// Pools are cached per (database URL, optional user name).
#[cfg(feature = "persistence")]
type PoolKey = (String, Option<String>);
#[cfg(feature = "persistence")]
type PoolValue = Arc<tokio::sync::OnceCell<PgPool>>;

/// Cache of PostgreSQL connection pools, one per distinct connection spec.
#[derive(Default)]
pub struct DbPools {
    // Sync mutex guards only the map; pool creation itself happens inside the
    // per-entry `OnceCell`, outside this lock.
    #[cfg(feature = "persistence")]
    pub pools: Mutex<HashMap<PoolKey, PoolValue>>,
}
216
impl DbPools {
    /// Returns a shared connection pool for `conn_spec`, creating it on first
    /// request. Pools are cached per `(url, user)` key; concurrent callers for
    /// the same key share a single initialization via the `OnceCell`.
    #[cfg(feature = "persistence")]
    pub async fn get_pool(&self, conn_spec: &settings::DatabaseConnectionSpec) -> Result<PgPool> {
        // Grab (or insert) the cell under the sync mutex in a tight scope so
        // the lock is released before any awaiting happens.
        let db_pool_cell = {
            let key = (conn_spec.url.clone(), conn_spec.user.clone());
            let mut db_pools = self.pools.lock().unwrap();
            db_pools.entry(key).or_default().clone()
        };
        let pool = db_pool_cell
            .get_or_try_init(|| async move {
                let mut pg_options: PgConnectOptions = conn_spec.url.parse()?;
                // User/password from the spec override whatever the URL carries.
                if let Some(user) = &conn_spec.user {
                    pg_options = pg_options.username(user);
                }
                if let Some(password) = &conn_spec.password {
                    pg_options = pg_options.password(password);
                }

                // Warm-up: connect with a throwaway single-connection pool and
                // a short timeout, and acquire once, so connectivity problems
                // surface quickly with a URL-specific error before the
                // long-lived pool is built.
                {
                    let pool_options = PgPoolOptions::new()
                        .max_connections(1)
                        .min_connections(1)
                        .acquire_timeout(Duration::from_secs(30));
                    let pool = pool_options
                        .connect_with(pg_options.clone())
                        .await
                        .map_err(Error::from)
                        .with_context(|| {
                            format!("Failed to connect to database {}", conn_spec.url)
                        })?;
                    let _ = pool.acquire().await?;
                }

                // The real pool, sized from the spec; slow acquisitions are
                // logged at INFO after 10s and time out after 5 minutes.
                let pool_options = PgPoolOptions::new()
                    .max_connections(conn_spec.max_connections)
                    .min_connections(conn_spec.min_connections)
                    .acquire_slow_level(log::LevelFilter::Info)
                    .acquire_slow_threshold(Duration::from_secs(10))
                    .acquire_timeout(Duration::from_secs(5 * 60));
                let pool = pool_options
                    .connect_with(pg_options)
                    .await
                    .map_err(Error::from)
                    .with_context(|| "Failed to connect to database")?;
                Ok::<_, Error>(pool)
            })
            .await?;
        Ok(pool.clone())
    }
}
269
/// Library-wide setup state loaded from the database, together with the
/// global change set derived from it.
#[cfg(feature = "persistence")]
pub struct LibSetupContext {
    pub all_setup_states: setup::AllSetupStates<setup::ExistingMode>,
    pub global_setup_change: setup::GlobalSetupChange,
}
/// Persistence-backed state of the library: the built-in database pool and
/// the lock-guarded setup context.
#[cfg(feature = "persistence")]
pub struct PersistenceContext {
    pub builtin_db_pool: PgPool,
    pub setup_ctx: tokio::sync::RwLock<LibSetupContext>,
}
280
/// Top-level library context shared across all flows in the process.
pub struct LibContext {
    pub db_pools: DbPools,
    /// `None` when no database is configured in the settings.
    #[cfg(feature = "persistence")]
    pub persistence_ctx: Option<PersistenceContext>,
    /// Registered flows keyed by name (BTreeMap for stable ordering).
    pub flows: Mutex<BTreeMap<String, Arc<FlowContext>>>,
    pub app_namespace: String,
    pub ignore_target_drop_failures: bool,
    pub global_concurrency_controller: Arc<concur_control::ConcurrencyController>,
}
291
292impl LibContext {
293 pub fn get_flow_context(&self, flow_name: &str) -> Result<Arc<FlowContext>> {
294 let flows = self.flows.lock().unwrap();
295 let flow_ctx = flows
296 .get(flow_name)
297 .ok_or_else(|| {
298 #[cfg(feature = "server")]
299 {
300 ApiError::new(
301 &format!("Flow instance not found: {flow_name}"),
302 StatusCode::NOT_FOUND,
303 )
304 }
305 #[cfg(not(feature = "server"))]
306 {
307 anyhow::anyhow!("Flow instance not found: {flow_name}")
308 }
309 })?
310 .clone();
311 Ok(flow_ctx)
312 }
313
314 pub fn remove_flow_context(&self, flow_name: &str) {
315 let mut flows = self.flows.lock().unwrap();
316 flows.remove(flow_name);
317 }
318
319 #[cfg(feature = "persistence")]
320 pub fn require_persistence_ctx(&self) -> Result<&PersistenceContext> {
321 self.persistence_ctx.as_ref().ok_or_else(|| {
322 client_error!(
323 "Database is required for this operation. \
324 The easiest way is to set COCOINDEX_DATABASE_URL environment variable. \
325 Please see https://CocoIndex/docs/core/settings for more details."
326 )
327 })
328 }
329
330 #[cfg(feature = "persistence")]
331 pub fn require_builtin_db_pool(&self) -> Result<&PgPool> {
332 Ok(&self.require_persistence_ctx()?.builtin_db_pool)
333 }
334}
335
// One-time process initialization guard (tracing + TLS crypto provider).
static LIB_INIT: OnceLock<()> = OnceLock::new();
/// Creates a new [`LibContext`] from `settings`.
///
/// On the first call in the process, also initializes tracing and (with the
/// relevant features) the rustls crypto provider. When a database is
/// configured, connects to it and loads the existing setup states.
pub async fn create_lib_context(settings: settings::Settings) -> Result<LibContext> {
    LIB_INIT.get_or_init(|| {
        let env_filter =
            EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info"));
        // `try_init` can fail if the host application already installed a
        // subscriber; that is deliberately ignored.
        let _ = tracing_subscriber::registry()
            .with(fmt::layer())
            .with(env_filter)
            .try_init();
        // Best-effort: another crate may have installed a provider already.
        #[cfg(any(feature = "server", feature = "source-gdrive", feature = "source-s3"))]
        let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();
    });

    let db_pools = DbPools::default();
    #[cfg(feature = "persistence")]
    let persistence_ctx = if let Some(database_spec) = &settings.database {
        let pool = db_pools.get_pool(database_spec).await?;
        let all_setup_states = setup::get_existing_setup_state(&pool).await?;
        Some(PersistenceContext {
            builtin_db_pool: pool,
            setup_ctx: tokio::sync::RwLock::new(LibSetupContext {
                global_setup_change: setup::GlobalSetupChange::from_setup_states(&all_setup_states),
                all_setup_states,
            }),
        })
    } else {
        None
    };

    Ok(LibContext {
        db_pools,
        #[cfg(feature = "persistence")]
        persistence_ctx,
        flows: Mutex::new(BTreeMap::new()),
        app_namespace: settings.app_namespace,
        ignore_target_drop_failures: settings.ignore_target_drop_failures,
        global_concurrency_controller: Arc::new(concur_control::ConcurrencyController::new(
            &concur_control::Options {
                max_inflight_rows: settings.global_execution_options.source_max_inflight_rows,
                max_inflight_bytes: settings.global_execution_options.source_max_inflight_bytes,
            },
        )),
    })
}
383
384#[allow(clippy::type_complexity)]
385static GET_SETTINGS_FN: Mutex<Option<Box<dyn Fn() -> Result<settings::Settings> + Send + Sync>>> =
386 Mutex::new(None);
387fn get_settings() -> Result<settings::Settings> {
388 let get_settings_fn = GET_SETTINGS_FN.lock().unwrap();
389 let settings = if let Some(get_settings_fn) = &*get_settings_fn {
390 get_settings_fn()?
391 } else {
392 client_bail!("CocoIndex setting function is not provided");
393 };
394 Ok(settings)
395}
396
397pub fn set_settings_fn(get_settings_fn: Box<dyn Fn() -> Result<settings::Settings> + Send + Sync>) {
398 let mut get_settings_fn_locked = GET_SETTINGS_FN.lock().unwrap();
399 *get_settings_fn_locked = Some(get_settings_fn);
400}
401
402static LIB_CONTEXT: LazyLock<tokio::sync::Mutex<Option<Arc<LibContext>>>> =
403 LazyLock::new(|| tokio::sync::Mutex::new(None));
404
405pub async fn init_lib_context(settings: Option<settings::Settings>) -> Result<()> {
406 let settings = match settings {
407 Some(settings) => settings,
408 None => get_settings()?,
409 };
410 let mut lib_context_locked = LIB_CONTEXT.lock().await;
411 *lib_context_locked = Some(Arc::new(create_lib_context(settings).await?));
412 Ok(())
413}
414
415pub async fn get_lib_context() -> Result<Arc<LibContext>> {
416 let mut lib_context_locked = LIB_CONTEXT.lock().await;
417 let lib_context = if let Some(lib_context) = &*lib_context_locked {
418 lib_context.clone()
419 } else {
420 let setting = get_settings()?;
421 let lib_context = Arc::new(create_lib_context(setting).await?);
422 *lib_context_locked = Some(lib_context.clone());
423 lib_context
424 };
425 Ok(lib_context)
426}
427
428pub async fn clear_lib_context() {
429 let mut lib_context_locked = LIB_CONTEXT.lock().await;
430 *lib_context_locked = None;
431}
432
#[cfg(test)]
mod tests {
    use super::*;

    /// A default `DbPools` starts with no cached pools.
    #[test]
    fn test_db_pools_default() {
        let _db_pools = DbPools::default();
        #[cfg(feature = "persistence")]
        assert!(_db_pools.pools.lock().unwrap().is_empty());
    }

    /// Without a database in settings, no persistence context is created and
    /// requiring the builtin pool fails.
    #[cfg(feature = "persistence")]
    #[tokio::test]
    async fn test_lib_context_without_database() {
        let lib_context = create_lib_context(settings::Settings::default())
            .await
            .unwrap();
        assert!(lib_context.persistence_ctx.is_none());
        assert!(lib_context.require_builtin_db_pool().is_err());
    }

    /// With an unreachable database URL, context creation fails rather than
    /// yielding a context with a broken pool.
    #[cfg(feature = "persistence")]
    #[tokio::test]
    async fn test_persistence_context_type_safety() {
        let settings = settings::Settings {
            database: Some(settings::DatabaseConnectionSpec {
                url: "postgresql://test".to_string(),
                user: None,
                password: None,
                max_connections: 10,
                min_connections: 1,
            }),
            ..Default::default()
        };

        // The warm-up connection in `DbPools::get_pool` should error out.
        let result = create_lib_context(settings).await;
        assert!(result.is_err());
    }
}