1use std::collections::hash_map::DefaultHasher;
2use std::hash::{Hash, Hasher};
3use std::num::NonZeroUsize;
4use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
5use std::sync::{Arc, OnceLock, RwLock};
6use std::time::{Duration, Instant};
7
8use crate::error::{KnowledgeResult, Reason};
9use async_trait::async_trait;
10use lru::LruCache;
11use orion_error::UvsFrom;
12use orion_error::conversion::ToStructError;
13use tokio::task;
14use wp_log::{debug_kdb, warn_kdb};
15use wp_model_core::model::{DataField, DataType, Value};
16
17use crate::loader::ProviderKind;
18use crate::mem::RowData;
19use crate::telemetry::{
20 CacheLayer, CacheOutcome, CacheTelemetryEvent, QueryTelemetryEvent, ReloadOutcome,
21 ReloadTelemetryEvent, telemetry, telemetry_enabled,
22};
23
/// Opaque identifier of a datasource, e.g. `"sqlite:<16 hex digits>"` when
/// built with [`DatasourceId::from_seed`].
///
/// Combined with a [`Generation`] it scopes result-cache keys and metadata
/// cache scopes to one provider installation.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct DatasourceId(pub String);
26
27impl DatasourceId {
28 pub fn from_seed(kind: ProviderKind, seed: &str) -> Self {
29 let mut hasher = DefaultHasher::new();
30 seed.hash(&mut hasher);
31 let kind_str = match kind {
32 ProviderKind::SqliteAuthority => "sqlite",
33 ProviderKind::Postgres => "postgres",
34 ProviderKind::Mysql => "mysql",
35 };
36 Self(format!("{kind_str}:{:016x}", hasher.finish()))
37 }
38}
39
/// Number identifying one provider installation.
///
/// `KnowledgeRuntime` hands these out from an incrementing counter starting at
/// 1; the value 0 is never allocated and is used as a "no provider" sentinel.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct Generation(pub u64);
42
/// Shape of the result a [`QueryRequest`] asks for.
#[derive(Clone, Debug)]
pub enum QueryMode {
    /// Every row produced by the statement.
    Many,
    /// Only the first row.
    FirstRow,
}
48
/// How a query may use caching.
#[derive(Clone, Copy, Debug)]
pub enum CachePolicy {
    /// Always hit the provider.
    Bypass,
    /// Read from and write to the runtime's global result cache.
    UseGlobal,
    /// Not acted on by `KnowledgeRuntime` itself (it behaves like `Bypass`
    /// there); presumably honoured by a caller-side, per-call cache —
    /// TODO confirm with callers.
    UseCallScope,
}
55
/// Provider-neutral bound parameter value.
///
/// Richer model values are flattened into these variants by
/// `fields_to_params` (most of them become `Text`).
#[derive(Clone, Debug)]
pub enum QueryValue {
    /// SQL `NULL`.
    Null,
    /// Boolean value.
    Bool(bool),
    /// Signed 64-bit integer.
    Int(i64),
    /// 64-bit float.
    Float(f64),
    /// UTF-8 text.
    Text(String),
}
64
65#[derive(Debug, Clone)]
66pub struct QueryParam {
67 pub name: String,
68 pub value: QueryValue,
69}
70
71#[derive(Debug, Clone)]
72pub struct QueryRequest {
73 pub sql: String,
74 pub params: Vec<QueryParam>,
75 pub mode: QueryMode,
76 pub cache_policy: CachePolicy,
77}
78
79impl QueryRequest {
80 pub fn many(
81 sql: impl Into<String>,
82 params: Vec<QueryParam>,
83 cache_policy: CachePolicy,
84 ) -> Self {
85 Self {
86 sql: sql.into(),
87 params,
88 mode: QueryMode::Many,
89 cache_policy,
90 }
91 }
92
93 pub fn first_row(
94 sql: impl Into<String>,
95 params: Vec<QueryParam>,
96 cache_policy: CachePolicy,
97 ) -> Self {
98 Self {
99 sql: sql.into(),
100 params,
101 mode: QueryMode::FirstRow,
102 cache_policy,
103 }
104 }
105}
106
107#[derive(Debug, Clone)]
108pub enum QueryResponse {
109 Rows(Vec<RowData>),
110 Row(RowData),
111}
112
113impl QueryResponse {
114 pub fn into_rows(self) -> Vec<RowData> {
115 match self {
116 QueryResponse::Rows(rows) => rows,
117 QueryResponse::Row(row) => vec![row],
118 }
119 }
120
121 pub fn into_row(self) -> RowData {
122 match self {
123 QueryResponse::Rows(rows) => rows.into_iter().next().unwrap_or_default(),
124 QueryResponse::Row(row) => row,
125 }
126 }
127}
128
129#[async_trait]
130pub trait ProviderExecutor: Send + Sync {
131 fn query(&self, sql: &str) -> KnowledgeResult<Vec<RowData>>;
132 fn query_fields(&self, sql: &str, params: &[DataField]) -> KnowledgeResult<Vec<RowData>>;
133 fn query_row(&self, sql: &str) -> KnowledgeResult<RowData>;
134 fn query_named_fields(&self, sql: &str, params: &[DataField]) -> KnowledgeResult<RowData>;
135
136 async fn query_async(&self, sql: &str) -> KnowledgeResult<Vec<RowData>> {
137 self.query(sql)
138 }
139
140 async fn query_fields_async(
141 &self,
142 sql: &str,
143 params: &[DataField],
144 ) -> KnowledgeResult<Vec<RowData>> {
145 self.query_fields(sql, params)
146 }
147
148 async fn query_row_async(&self, sql: &str) -> KnowledgeResult<RowData> {
149 self.query_row(sql)
150 }
151
152 async fn query_named_fields_async(
153 &self,
154 sql: &str,
155 params: &[DataField],
156 ) -> KnowledgeResult<RowData> {
157 self.query_named_fields(sql, params)
158 }
159}
160
/// Copyable, hashable counterpart of [`QueryMode`], used in cache keys and
/// telemetry events.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum QueryModeTag {
    /// All rows.
    Many,
    /// First row only.
    FirstRow,
}
166
167#[derive(Debug, Clone, PartialEq, Eq, Hash)]
168pub struct ResultCacheKey {
169 pub datasource_id: DatasourceId,
170 pub generation: Generation,
171 pub query_hash: u64,
172 pub params_hash: u64,
173 pub mode: QueryModeTag,
174}
175
176pub struct ProviderHandle {
177 pub provider: Arc<dyn ProviderExecutor>,
178 pub datasource_id: DatasourceId,
179 pub generation: Generation,
180 pub kind: ProviderKind,
181}
182
183#[derive(Debug, Clone)]
184pub struct RuntimeSnapshot {
185 pub provider_kind: Option<ProviderKind>,
186 pub datasource_id: Option<DatasourceId>,
187 pub generation: Option<Generation>,
188 pub result_cache_enabled: bool,
189 pub result_cache_len: usize,
190 pub result_cache_capacity: usize,
191 pub result_cache_ttl_ms: u64,
192 pub metadata_cache_len: usize,
193 pub metadata_cache_capacity: usize,
194 pub result_cache_hits: u64,
195 pub result_cache_misses: u64,
196 pub metadata_cache_hits: u64,
197 pub metadata_cache_misses: u64,
198 pub local_cache_hits: u64,
199 pub local_cache_misses: u64,
200 pub reload_successes: u64,
201 pub reload_failures: u64,
202}
203
204#[derive(Debug, Clone)]
205pub struct MetadataCacheScope {
206 pub datasource_id: DatasourceId,
207 pub generation: Generation,
208}
209
/// Configuration of the global result cache.
#[derive(Clone, Copy, Debug)]
pub struct ResultCacheConfig {
    /// Whether `CachePolicy::UseGlobal` queries consult the cache.
    pub enabled: bool,
    /// Maximum number of cached responses.
    pub capacity: usize,
    /// How long a cached response stays valid.
    pub ttl: Duration,
}
216
217impl Default for ResultCacheConfig {
218 fn default() -> Self {
219 Self {
220 enabled: true,
221 capacity: 1024,
222 ttl: Duration::from_millis(30_000),
223 }
224 }
225}
226
227#[derive(Debug, Clone)]
228struct CachedQueryResponse {
229 response: Arc<QueryResponse>,
230 cached_at: Instant,
231}
232
233pub struct KnowledgeRuntime {
234 provider: RwLock<Option<Arc<ProviderHandle>>>,
235 next_generation: AtomicU64,
236 provider_epoch: AtomicU64,
237 current_generation_value: AtomicU64,
238 result_cache_config: RwLock<ResultCacheConfig>,
239 result_cache_enabled: AtomicBool,
240 result_cache_ttl_ms: AtomicU64,
241 result_cache: RwLock<LruCache<ResultCacheKey, CachedQueryResponse>>,
242 result_cache_hits: AtomicU64,
243 result_cache_misses: AtomicU64,
244 metadata_cache_hits: AtomicU64,
245 metadata_cache_misses: AtomicU64,
246 local_cache_hits: AtomicU64,
247 local_cache_misses: AtomicU64,
248 reload_successes: AtomicU64,
249 reload_failures: AtomicU64,
250}
251
252impl KnowledgeRuntime {
253 pub fn new(result_cache_capacity: usize) -> Self {
254 let config = ResultCacheConfig {
255 capacity: result_cache_capacity.max(1),
256 ..ResultCacheConfig::default()
257 };
258 let capacity = NonZeroUsize::new(config.capacity).expect("non-zero capacity");
259 Self {
260 provider: RwLock::new(None),
261 next_generation: AtomicU64::new(0),
262 provider_epoch: AtomicU64::new(0),
263 current_generation_value: AtomicU64::new(0),
264 result_cache_config: RwLock::new(config),
265 result_cache_enabled: AtomicBool::new(config.enabled),
266 result_cache_ttl_ms: AtomicU64::new(config.ttl.as_millis() as u64),
267 result_cache: RwLock::new(LruCache::new(capacity)),
268 result_cache_hits: AtomicU64::new(0),
269 result_cache_misses: AtomicU64::new(0),
270 metadata_cache_hits: AtomicU64::new(0),
271 metadata_cache_misses: AtomicU64::new(0),
272 local_cache_hits: AtomicU64::new(0),
273 local_cache_misses: AtomicU64::new(0),
274 reload_successes: AtomicU64::new(0),
275 reload_failures: AtomicU64::new(0),
276 }
277 }
278
279 pub fn install_provider<F>(
280 &self,
281 kind: ProviderKind,
282 datasource_id: DatasourceId,
283 build: F,
284 ) -> KnowledgeResult<Generation>
285 where
286 F: FnOnce(Generation) -> KnowledgeResult<Arc<dyn ProviderExecutor>>,
287 {
288 let generation = Generation(self.next_generation.fetch_add(1, Ordering::SeqCst) + 1);
289 let previous = self
290 .provider
291 .read()
292 .ok()
293 .and_then(|guard| guard.as_ref().cloned());
294 debug_kdb!(
295 "[kdb] reload provider start kind={kind:?} datasource_id={} target_generation={} previous_generation={}",
296 datasource_id.0,
297 generation.0,
298 previous
299 .as_ref()
300 .map(|handle| handle.generation.0.to_string())
301 .unwrap_or_else(|| "none".to_string())
302 );
303 let provider = match build(generation) {
304 Ok(provider) => provider,
305 Err(err) => {
306 self.reload_failures.fetch_add(1, Ordering::Relaxed);
307 warn_kdb!(
308 "[kdb] reload provider failed kind={kind:?} datasource_id={} target_generation={} err={}",
309 datasource_id.0,
310 generation.0,
311 err
312 );
313 if telemetry_enabled() {
314 telemetry().on_reload(&ReloadTelemetryEvent {
315 outcome: ReloadOutcome::Failure,
316 provider_kind: kind.clone(),
317 });
318 }
319 return Err(err);
320 }
321 };
322 debug_kdb!(
323 "[kdb] install provider kind={kind:?} datasource_id={} generation={}",
324 datasource_id.0,
325 generation.0
326 );
327 let kind_for_handle = kind.clone();
328 let datasource_id_for_handle = datasource_id.clone();
329 let handle = Arc::new(ProviderHandle {
330 provider,
331 datasource_id: datasource_id_for_handle,
332 generation,
333 kind: kind_for_handle,
334 });
335 self.provider_epoch.fetch_add(1, Ordering::AcqRel);
336 {
337 let mut guard = self
338 .provider
339 .write()
340 .expect("runtime provider lock poisoned");
341 *guard = Some(handle);
342 }
343 self.current_generation_value
344 .store(generation.0, Ordering::Release);
345 self.provider_epoch.fetch_add(1, Ordering::Release);
346 self.reload_successes.fetch_add(1, Ordering::Relaxed);
347 if telemetry_enabled() {
348 telemetry().on_reload(&ReloadTelemetryEvent {
349 outcome: ReloadOutcome::Success,
350 provider_kind: kind.clone(),
351 });
352 }
353 debug_kdb!(
354 "[kdb] reload provider success kind={kind:?} datasource_id={} generation={}",
355 datasource_id.0,
356 generation.0
357 );
358 Ok(generation)
359 }
360
361 pub fn configure_result_cache(&self, enabled: bool, capacity: usize, ttl: Duration) {
362 let new_config = ResultCacheConfig {
363 enabled,
364 capacity: capacity.max(1),
365 ttl: ttl.max(Duration::from_millis(1)),
366 };
367 let mut should_reset_cache = false;
368 {
369 let mut guard = self
370 .result_cache_config
371 .write()
372 .expect("runtime result cache config lock poisoned");
373 if guard.capacity != new_config.capacity || (!new_config.enabled && guard.enabled) {
374 should_reset_cache = true;
375 }
376 *guard = new_config;
377 }
378 self.result_cache_enabled
379 .store(new_config.enabled, Ordering::Relaxed);
380 self.result_cache_ttl_ms.store(
381 new_config.ttl.as_millis().min(u128::from(u64::MAX)) as u64,
382 Ordering::Relaxed,
383 );
384
385 if should_reset_cache {
386 let mut cache = self
387 .result_cache
388 .write()
389 .expect("runtime result cache lock poisoned");
390 *cache = LruCache::new(
391 NonZeroUsize::new(new_config.capacity).expect("non-zero result cache capacity"),
392 );
393 }
394 }
395
396 pub fn current_generation(&self) -> Option<Generation> {
397 let epoch_before = self.provider_epoch.load(Ordering::Acquire);
398 if epoch_before % 2 == 1 {
399 return self.current_generation_from_provider();
400 }
401 let generation = self.current_generation_value.load(Ordering::Acquire);
402 let epoch_after = self.provider_epoch.load(Ordering::Acquire);
403 if epoch_before != epoch_after {
404 return self.current_generation_from_provider();
405 }
406 match generation {
407 0 => None,
408 generation => Some(Generation(generation)),
409 }
410 }
411
412 pub fn snapshot(&self) -> RuntimeSnapshot {
413 let provider = self
414 .provider
415 .read()
416 .ok()
417 .and_then(|guard| guard.as_ref().cloned());
418 let result_cache_config = self
419 .result_cache_config
420 .read()
421 .map(|guard| *guard)
422 .unwrap_or_default();
423 let (result_cache_len, result_cache_capacity) = self
424 .result_cache
425 .read()
426 .map(|cache| (cache.len(), cache.cap().get()))
427 .unwrap_or((0, 0));
428 let (metadata_cache_len, metadata_cache_capacity) =
429 crate::mem::query_util::column_metadata_cache_snapshot();
430 RuntimeSnapshot {
431 provider_kind: provider.as_ref().map(|handle| handle.kind.clone()),
432 datasource_id: provider.as_ref().map(|handle| handle.datasource_id.clone()),
433 generation: provider.as_ref().map(|handle| handle.generation),
434 result_cache_enabled: result_cache_config.enabled,
435 result_cache_len,
436 result_cache_capacity,
437 result_cache_ttl_ms: result_cache_config.ttl.as_millis() as u64,
438 metadata_cache_len,
439 metadata_cache_capacity,
440 result_cache_hits: self.result_cache_hits.load(Ordering::Relaxed),
441 result_cache_misses: self.result_cache_misses.load(Ordering::Relaxed),
442 metadata_cache_hits: self.metadata_cache_hits.load(Ordering::Relaxed),
443 metadata_cache_misses: self.metadata_cache_misses.load(Ordering::Relaxed),
444 local_cache_hits: self.local_cache_hits.load(Ordering::Relaxed),
445 local_cache_misses: self.local_cache_misses.load(Ordering::Relaxed),
446 reload_successes: self.reload_successes.load(Ordering::Relaxed),
447 reload_failures: self.reload_failures.load(Ordering::Relaxed),
448 }
449 }
450
451 pub fn current_metadata_scope(&self) -> MetadataCacheScope {
452 self.provider
453 .read()
454 .ok()
455 .and_then(|guard| guard.as_ref().cloned())
456 .map(|handle| MetadataCacheScope {
457 datasource_id: handle.datasource_id.clone(),
458 generation: handle.generation,
459 })
460 .unwrap_or_else(|| MetadataCacheScope {
461 datasource_id: DatasourceId("sqlite:standalone".to_string()),
462 generation: Generation(0),
463 })
464 }
465
466 pub fn current_provider_kind(&self) -> Option<ProviderKind> {
467 self.provider
468 .read()
469 .ok()
470 .and_then(|guard| guard.as_ref().map(|handle| handle.kind.clone()))
471 }
472
473 pub fn record_result_cache_hit(&self) {
474 self.result_cache_hits.fetch_add(1, Ordering::Relaxed);
475 }
476
477 pub fn record_result_cache_miss(&self) {
478 self.result_cache_misses.fetch_add(1, Ordering::Relaxed);
479 }
480
481 pub fn record_metadata_cache_hit(&self) {
482 self.metadata_cache_hits.fetch_add(1, Ordering::Relaxed);
483 }
484
485 pub fn record_metadata_cache_miss(&self) {
486 self.metadata_cache_misses.fetch_add(1, Ordering::Relaxed);
487 }
488
489 pub fn record_local_cache_hit(&self) {
490 self.local_cache_hits.fetch_add(1, Ordering::Relaxed);
491 }
492
493 pub fn record_local_cache_miss(&self) {
494 self.local_cache_misses.fetch_add(1, Ordering::Relaxed);
495 }
496
497 pub fn execute(&self, req: &QueryRequest) -> KnowledgeResult<QueryResponse> {
498 let handle = self.current_handle()?;
499 self.execute_with_handle(&handle, req)
500 }
501
502 fn execute_with_handle(
503 &self,
504 handle: &Arc<ProviderHandle>,
505 req: &QueryRequest,
506 ) -> KnowledgeResult<QueryResponse> {
507 let use_global_cache =
508 matches!(req.cache_policy, CachePolicy::UseGlobal) && self.result_cache_enabled();
509 if use_global_cache && let Some(hit) = self.fetch_result_cache(handle, req) {
510 self.record_result_cache_hit();
511 if telemetry_enabled() {
512 telemetry().on_cache(&CacheTelemetryEvent {
513 layer: CacheLayer::Result,
514 outcome: CacheOutcome::Hit,
515 provider_kind: Some(handle.kind.clone()),
516 });
517 }
518 debug_kdb!(
519 "[kdb] global result cache hit kind={:?} generation={}",
520 handle.kind,
521 handle.generation.0
522 );
523 return Ok(hit);
524 }
525 if use_global_cache {
526 self.record_result_cache_miss();
527 if telemetry_enabled() {
528 telemetry().on_cache(&CacheTelemetryEvent {
529 layer: CacheLayer::Result,
530 outcome: CacheOutcome::Miss,
531 provider_kind: Some(handle.kind.clone()),
532 });
533 }
534 debug_kdb!(
535 "[kdb] global result cache miss kind={:?} generation={}",
536 handle.kind,
537 handle.generation.0
538 );
539 }
540
541 let params = params_to_fields(&req.params);
542 let mode_tag = query_mode_tag(&req.mode);
543 let started = Instant::now();
544 debug_kdb!(
545 "[kdb] execute query kind={:?} generation={} mode={:?} cache_policy={:?}",
546 handle.kind,
547 handle.generation.0,
548 req.mode,
549 req.cache_policy
550 );
551 let response = match match req.mode {
552 QueryMode::Many => {
553 if params.is_empty() {
554 handle.provider.query(&req.sql).map(QueryResponse::Rows)
555 } else {
556 handle
557 .provider
558 .query_fields(&req.sql, ¶ms)
559 .map(QueryResponse::Rows)
560 }
561 }
562 QueryMode::FirstRow => {
563 if params.is_empty() {
564 handle.provider.query_row(&req.sql).map(QueryResponse::Row)
565 } else {
566 handle
567 .provider
568 .query_named_fields(&req.sql, ¶ms)
569 .map(QueryResponse::Row)
570 }
571 }
572 } {
573 Ok(response) => {
574 if telemetry_enabled() {
575 telemetry().on_query(&QueryTelemetryEvent {
576 provider_kind: handle.kind.clone(),
577 mode: mode_tag,
578 success: true,
579 elapsed: started.elapsed(),
580 });
581 }
582 response
583 }
584 Err(err) => {
585 if telemetry_enabled() {
586 telemetry().on_query(&QueryTelemetryEvent {
587 provider_kind: handle.kind.clone(),
588 mode: mode_tag,
589 success: false,
590 elapsed: started.elapsed(),
591 });
592 }
593 return Err(err);
594 }
595 };
596
597 if use_global_cache {
598 self.save_result_cache(handle, req, response.clone());
599 debug_kdb!(
600 "[kdb] global result cache store kind={:?} generation={}",
601 handle.kind,
602 handle.generation.0
603 );
604 }
605
606 Ok(response)
607 }
608
609 pub fn execute_first_row_fields(
610 &self,
611 sql: &str,
612 params: &[DataField],
613 cache_policy: CachePolicy,
614 ) -> KnowledgeResult<RowData> {
615 let handle = self.current_handle()?;
616 self.execute_first_row_fields_with_handle(&handle, sql, params, cache_policy)
617 }
618
619 fn execute_first_row_fields_with_handle(
620 &self,
621 handle: &Arc<ProviderHandle>,
622 sql: &str,
623 params: &[DataField],
624 cache_policy: CachePolicy,
625 ) -> KnowledgeResult<RowData> {
626 let use_global_cache =
627 matches!(cache_policy, CachePolicy::UseGlobal) && self.result_cache_enabled();
628 if use_global_cache
629 && let Some(hit) = self.fetch_result_cache_by_key(result_cache_key_fields(
630 handle,
631 sql,
632 params,
633 QueryModeTag::FirstRow,
634 ))
635 {
636 self.record_result_cache_hit();
637 if telemetry_enabled() {
638 telemetry().on_cache(&CacheTelemetryEvent {
639 layer: CacheLayer::Result,
640 outcome: CacheOutcome::Hit,
641 provider_kind: Some(handle.kind.clone()),
642 });
643 }
644 return Ok(hit.into_row());
645 }
646 if use_global_cache {
647 self.record_result_cache_miss();
648 if telemetry_enabled() {
649 telemetry().on_cache(&CacheTelemetryEvent {
650 layer: CacheLayer::Result,
651 outcome: CacheOutcome::Miss,
652 provider_kind: Some(handle.kind.clone()),
653 });
654 }
655 }
656
657 let started = Instant::now();
658 let row = if params.is_empty() {
659 handle.provider.query_row(sql)
660 } else {
661 handle.provider.query_named_fields(sql, params)
662 };
663 let row = match row {
664 Ok(row) => {
665 if telemetry_enabled() {
666 telemetry().on_query(&QueryTelemetryEvent {
667 provider_kind: handle.kind.clone(),
668 mode: QueryModeTag::FirstRow,
669 success: true,
670 elapsed: started.elapsed(),
671 });
672 }
673 row
674 }
675 Err(err) => {
676 if telemetry_enabled() {
677 telemetry().on_query(&QueryTelemetryEvent {
678 provider_kind: handle.kind.clone(),
679 mode: QueryModeTag::FirstRow,
680 success: false,
681 elapsed: started.elapsed(),
682 });
683 }
684 return Err(err);
685 }
686 };
687
688 if use_global_cache {
689 self.save_result_cache_by_key(
690 result_cache_key_fields(handle, sql, params, QueryModeTag::FirstRow),
691 QueryResponse::Row(row.clone()),
692 );
693 }
694
695 Ok(row)
696 }
697
698 pub async fn execute_async(&self, req: &QueryRequest) -> KnowledgeResult<QueryResponse> {
699 let handle = self.current_handle()?;
700 if matches!(handle.kind, ProviderKind::SqliteAuthority) {
701 let handle = handle.clone();
702 let req = req.clone();
703 return task::spawn_blocking(move || runtime().execute_with_handle(&handle, &req))
704 .await
705 .map_err(|err| {
706 Reason::from_logic()
707 .to_err()
708 .with_detail(format!("knowledge async sqlite query join failed: {err}"))
709 })?;
710 }
711 let use_global_cache =
712 matches!(req.cache_policy, CachePolicy::UseGlobal) && self.result_cache_enabled();
713 if use_global_cache && let Some(hit) = self.fetch_result_cache(&handle, req) {
714 self.record_result_cache_hit();
715 if telemetry_enabled() {
716 telemetry().on_cache(&CacheTelemetryEvent {
717 layer: CacheLayer::Result,
718 outcome: CacheOutcome::Hit,
719 provider_kind: Some(handle.kind.clone()),
720 });
721 }
722 return Ok(hit);
723 }
724 if use_global_cache {
725 self.record_result_cache_miss();
726 if telemetry_enabled() {
727 telemetry().on_cache(&CacheTelemetryEvent {
728 layer: CacheLayer::Result,
729 outcome: CacheOutcome::Miss,
730 provider_kind: Some(handle.kind.clone()),
731 });
732 }
733 }
734
735 let params = params_to_fields(&req.params);
736 let mode_tag = query_mode_tag(&req.mode);
737 let started = Instant::now();
738 let response = match req.mode {
739 QueryMode::Many => {
740 if params.is_empty() {
741 handle
742 .provider
743 .query_async(&req.sql)
744 .await
745 .map(QueryResponse::Rows)
746 } else {
747 handle
748 .provider
749 .query_fields_async(&req.sql, ¶ms)
750 .await
751 .map(QueryResponse::Rows)
752 }
753 }
754 QueryMode::FirstRow => {
755 if params.is_empty() {
756 handle
757 .provider
758 .query_row_async(&req.sql)
759 .await
760 .map(QueryResponse::Row)
761 } else {
762 handle
763 .provider
764 .query_named_fields_async(&req.sql, ¶ms)
765 .await
766 .map(QueryResponse::Row)
767 }
768 }
769 };
770 let response = match response {
771 Ok(response) => {
772 if telemetry_enabled() {
773 telemetry().on_query(&QueryTelemetryEvent {
774 provider_kind: handle.kind.clone(),
775 mode: mode_tag,
776 success: true,
777 elapsed: started.elapsed(),
778 });
779 }
780 response
781 }
782 Err(err) => {
783 if telemetry_enabled() {
784 telemetry().on_query(&QueryTelemetryEvent {
785 provider_kind: handle.kind.clone(),
786 mode: mode_tag,
787 success: false,
788 elapsed: started.elapsed(),
789 });
790 }
791 return Err(err);
792 }
793 };
794
795 if use_global_cache {
796 self.save_result_cache(&handle, req, response.clone());
797 }
798
799 Ok(response)
800 }
801
802 pub async fn execute_first_row_fields_async(
803 &self,
804 sql: &str,
805 params: &[DataField],
806 cache_policy: CachePolicy,
807 ) -> KnowledgeResult<RowData> {
808 let handle = self.current_handle()?;
809 if matches!(handle.kind, ProviderKind::SqliteAuthority) {
810 let handle = handle.clone();
811 let sql = sql.to_string();
812 let params = params.to_vec();
813 return task::spawn_blocking(move || {
814 runtime().execute_first_row_fields_with_handle(&handle, &sql, ¶ms, cache_policy)
815 })
816 .await
817 .map_err(|err| {
818 Reason::from_logic().to_err().with_detail(format!(
819 "knowledge async sqlite first-row query join failed: {err}"
820 ))
821 })?;
822 }
823 let use_global_cache =
824 matches!(cache_policy, CachePolicy::UseGlobal) && self.result_cache_enabled();
825 if use_global_cache
826 && let Some(hit) = self.fetch_result_cache_by_key(result_cache_key_fields(
827 &handle,
828 sql,
829 params,
830 QueryModeTag::FirstRow,
831 ))
832 {
833 self.record_result_cache_hit();
834 if telemetry_enabled() {
835 telemetry().on_cache(&CacheTelemetryEvent {
836 layer: CacheLayer::Result,
837 outcome: CacheOutcome::Hit,
838 provider_kind: Some(handle.kind.clone()),
839 });
840 }
841 return Ok(hit.into_row());
842 }
843 if use_global_cache {
844 self.record_result_cache_miss();
845 if telemetry_enabled() {
846 telemetry().on_cache(&CacheTelemetryEvent {
847 layer: CacheLayer::Result,
848 outcome: CacheOutcome::Miss,
849 provider_kind: Some(handle.kind.clone()),
850 });
851 }
852 }
853
854 let started = Instant::now();
855 let row = if params.is_empty() {
856 handle.provider.query_row_async(sql).await
857 } else {
858 handle.provider.query_named_fields_async(sql, params).await
859 };
860 let row = match row {
861 Ok(row) => {
862 if telemetry_enabled() {
863 telemetry().on_query(&QueryTelemetryEvent {
864 provider_kind: handle.kind.clone(),
865 mode: QueryModeTag::FirstRow,
866 success: true,
867 elapsed: started.elapsed(),
868 });
869 }
870 row
871 }
872 Err(err) => {
873 if telemetry_enabled() {
874 telemetry().on_query(&QueryTelemetryEvent {
875 provider_kind: handle.kind.clone(),
876 mode: QueryModeTag::FirstRow,
877 success: false,
878 elapsed: started.elapsed(),
879 });
880 }
881 return Err(err);
882 }
883 };
884
885 if use_global_cache {
886 self.save_result_cache_by_key(
887 result_cache_key_fields(&handle, sql, params, QueryModeTag::FirstRow),
888 QueryResponse::Row(row.clone()),
889 );
890 }
891
892 Ok(row)
893 }
894
895 fn current_handle(&self) -> KnowledgeResult<Arc<ProviderHandle>> {
896 self.provider
897 .read()
898 .expect("runtime provider lock poisoned")
899 .clone()
900 .ok_or_else(|| {
901 Reason::from_logic()
902 .to_err()
903 .with_detail("knowledge provider not initialized")
904 })
905 }
906
907 fn current_generation_from_provider(&self) -> Option<Generation> {
908 self.provider
909 .read()
910 .ok()
911 .and_then(|guard| guard.as_ref().map(|handle| handle.generation))
912 }
913
914 fn fetch_result_cache(
915 &self,
916 handle: &ProviderHandle,
917 req: &QueryRequest,
918 ) -> Option<QueryResponse> {
919 self.fetch_result_cache_by_key(result_cache_key(handle, req))
920 }
921
922 fn fetch_result_cache_by_key(&self, key: ResultCacheKey) -> Option<QueryResponse> {
923 if !self.result_cache_enabled() {
924 return None;
925 }
926 let cached = self
927 .result_cache
928 .read()
929 .ok()
930 .and_then(|cache| cache.peek(&key).cloned())?;
931 if cached.cached_at.elapsed() > self.result_cache_ttl() {
932 if let Ok(mut cache) = self.result_cache.write() {
933 let _ = cache.pop(&key);
934 }
935 return None;
936 }
937 Some((*cached.response).clone())
938 }
939
940 fn save_result_cache(
941 &self,
942 handle: &ProviderHandle,
943 req: &QueryRequest,
944 response: QueryResponse,
945 ) {
946 self.save_result_cache_by_key(result_cache_key(handle, req), response);
947 }
948
949 fn save_result_cache_by_key(&self, key: ResultCacheKey, response: QueryResponse) {
950 if let Ok(mut cache) = self.result_cache.write() {
951 cache.put(
952 key,
953 CachedQueryResponse {
954 response: Arc::new(response),
955 cached_at: Instant::now(),
956 },
957 );
958 }
959 }
960
961 #[inline]
962 fn result_cache_enabled(&self) -> bool {
963 self.result_cache_enabled.load(Ordering::Relaxed)
964 }
965
966 #[inline]
967 fn result_cache_ttl(&self) -> Duration {
968 Duration::from_millis(self.result_cache_ttl_ms.load(Ordering::Relaxed))
969 }
970}
971
972pub fn runtime() -> &'static KnowledgeRuntime {
973 static RUNTIME: OnceLock<KnowledgeRuntime> = OnceLock::new();
974 RUNTIME.get_or_init(|| KnowledgeRuntime::new(1024))
975}
976
/// Test-only mutex that serializes tests touching the global [`runtime()`].
#[cfg(test)]
pub(crate) struct RuntimeTestGuard(tokio::sync::Mutex<()>);
979
#[cfg(test)]
impl RuntimeTestGuard {
    /// Blocking acquire for synchronous tests.
    ///
    /// The `Result` can never be `Err`. Presumably it exists so call sites
    /// written for `std::sync::Mutex::lock` (`.lock().unwrap()`) work
    /// unchanged.
    ///
    /// # Panics
    /// Tokio's `blocking_lock` panics when called from within an asynchronous
    /// execution context; use [`lock_async`](Self::lock_async) there.
    pub(crate) fn lock(&self) -> Result<tokio::sync::MutexGuard<'_, ()>, std::convert::Infallible> {
        let guard = self.0.blocking_lock();
        Ok(guard)
    }

    /// Async acquire for `#[tokio::test]` tests.
    pub(crate) async fn lock_async(&self) -> tokio::sync::MutexGuard<'_, ()> {
        let Self(mutex) = self;
        mutex.lock().await
    }
}
990
/// Returns the shared [`RuntimeTestGuard`], creating it on first use.
#[cfg(test)]
pub(crate) fn runtime_test_guard() -> &'static RuntimeTestGuard {
    static TEST_GUARD: OnceLock<RuntimeTestGuard> = OnceLock::new();
    TEST_GUARD.get_or_init(|| {
        let mutex = tokio::sync::Mutex::new(());
        RuntimeTestGuard(mutex)
    })
}
996
997fn result_cache_key(handle: &ProviderHandle, req: &QueryRequest) -> ResultCacheKey {
998 ResultCacheKey {
999 datasource_id: handle.datasource_id.clone(),
1000 generation: handle.generation,
1001 query_hash: stable_hash(&req.sql),
1002 params_hash: stable_params_hash(&req.params),
1003 mode: match req.mode {
1004 QueryMode::Many => QueryModeTag::Many,
1005 QueryMode::FirstRow => QueryModeTag::FirstRow,
1006 },
1007 }
1008}
1009
1010fn result_cache_key_fields(
1011 handle: &ProviderHandle,
1012 sql: &str,
1013 params: &[DataField],
1014 mode: QueryModeTag,
1015) -> ResultCacheKey {
1016 ResultCacheKey {
1017 datasource_id: handle.datasource_id.clone(),
1018 generation: handle.generation,
1019 query_hash: stable_hash(sql),
1020 params_hash: stable_field_params_hash(params),
1021 mode,
1022 }
1023}
1024
1025fn query_mode_tag(mode: &QueryMode) -> QueryModeTag {
1026 match mode {
1027 QueryMode::Many => QueryModeTag::Many,
1028 QueryMode::FirstRow => QueryModeTag::FirstRow,
1029 }
1030}
1031
/// Hashes a string with `DefaultHasher::new()`.
///
/// The result is deterministic within a process (fixed keys). It is not
/// guaranteed to be stable across Rust releases, which is fine for the
/// in-memory cache keys it feeds.
fn stable_hash(value: &str) -> u64 {
    let mut state = DefaultHasher::new();
    Hash::hash(value, &mut state);
    state.finish()
}
1037
1038fn stable_params_hash(params: &[QueryParam]) -> u64 {
1039 let mut hasher = DefaultHasher::new();
1040 for param in params {
1041 param.name.hash(&mut hasher);
1042 match ¶m.value {
1043 QueryValue::Null => 0u8.hash(&mut hasher),
1044 QueryValue::Bool(value) => {
1045 1u8.hash(&mut hasher);
1046 value.hash(&mut hasher);
1047 }
1048 QueryValue::Int(value) => {
1049 2u8.hash(&mut hasher);
1050 value.hash(&mut hasher);
1051 }
1052 QueryValue::Float(value) => {
1053 3u8.hash(&mut hasher);
1054 value.to_bits().hash(&mut hasher);
1055 }
1056 QueryValue::Text(value) => {
1057 4u8.hash(&mut hasher);
1058 value.hash(&mut hasher);
1059 }
1060 }
1061 }
1062 hasher.finish()
1063}
1064
1065fn stable_field_params_hash(params: &[DataField]) -> u64 {
1066 let mut hasher = DefaultHasher::new();
1067 for field in params {
1068 field.get_name().hash(&mut hasher);
1069 match field.get_value() {
1070 Value::Null | Value::Ignore(_) => 0u8.hash(&mut hasher),
1071 Value::Bool(value) => {
1072 1u8.hash(&mut hasher);
1073 value.hash(&mut hasher);
1074 }
1075 Value::Digit(value) => {
1076 2u8.hash(&mut hasher);
1077 value.hash(&mut hasher);
1078 }
1079 Value::Float(value) => {
1080 3u8.hash(&mut hasher);
1081 value.to_bits().hash(&mut hasher);
1082 }
1083 Value::Chars(value) => {
1084 4u8.hash(&mut hasher);
1085 value.hash(&mut hasher);
1086 }
1087 Value::Symbol(value) => {
1088 5u8.hash(&mut hasher);
1089 value.hash(&mut hasher);
1090 }
1091 Value::Time(value) => {
1092 6u8.hash(&mut hasher);
1093 value.hash(&mut hasher);
1094 }
1095 Value::Hex(value) => {
1096 7u8.hash(&mut hasher);
1097 value.to_string().hash(&mut hasher);
1098 }
1099 Value::IpNet(value) => {
1100 8u8.hash(&mut hasher);
1101 value.to_string().hash(&mut hasher);
1102 }
1103 Value::IpAddr(value) => {
1104 9u8.hash(&mut hasher);
1105 value.hash(&mut hasher);
1106 }
1107 Value::Obj(value) => {
1108 10u8.hash(&mut hasher);
1109 format!("{:?}", value).hash(&mut hasher);
1110 }
1111 Value::Array(value) => {
1112 11u8.hash(&mut hasher);
1113 format!("{:?}", value).hash(&mut hasher);
1114 }
1115 Value::Domain(value) => {
1116 12u8.hash(&mut hasher);
1117 value.0.hash(&mut hasher);
1118 }
1119 Value::Url(value) => {
1120 13u8.hash(&mut hasher);
1121 value.0.hash(&mut hasher);
1122 }
1123 Value::Email(value) => {
1124 14u8.hash(&mut hasher);
1125 value.0.hash(&mut hasher);
1126 }
1127 Value::IdCard(value) => {
1128 15u8.hash(&mut hasher);
1129 value.0.hash(&mut hasher);
1130 }
1131 Value::MobilePhone(value) => {
1132 16u8.hash(&mut hasher);
1133 value.0.hash(&mut hasher);
1134 }
1135 }
1136 }
1137 hasher.finish()
1138}
1139
1140pub fn fields_to_params(params: &[DataField]) -> Vec<QueryParam> {
1141 params
1142 .iter()
1143 .map(|field| {
1144 let value = match field.get_value() {
1145 Value::Null | Value::Ignore(_) => QueryValue::Null,
1146 Value::Bool(value) => QueryValue::Bool(*value),
1147 Value::Digit(value) => QueryValue::Int(*value),
1148 Value::Float(value) => QueryValue::Float(*value),
1149 Value::Chars(value) => QueryValue::Text(value.to_string()),
1150 Value::Symbol(value) => QueryValue::Text(value.to_string()),
1151 Value::Time(value) => QueryValue::Text(value.to_string()),
1152 Value::Hex(value) => QueryValue::Text(value.to_string()),
1153 Value::IpNet(value) => QueryValue::Text(value.to_string()),
1154 Value::IpAddr(value) => QueryValue::Text(value.to_string()),
1155 Value::Obj(value) => QueryValue::Text(format!("{:?}", value)),
1156 Value::Array(value) => QueryValue::Text(format!("{:?}", value)),
1157 Value::Domain(value) => QueryValue::Text(value.0.to_string()),
1158 Value::Url(value) => QueryValue::Text(value.0.to_string()),
1159 Value::Email(value) => QueryValue::Text(value.0.to_string()),
1160 Value::IdCard(value) => QueryValue::Text(value.0.to_string()),
1161 Value::MobilePhone(value) => QueryValue::Text(value.0.to_string()),
1162 };
1163 QueryParam {
1164 name: field.get_name().to_string(),
1165 value,
1166 }
1167 })
1168 .collect()
1169}
1170
1171pub fn params_to_fields(params: &[QueryParam]) -> Vec<DataField> {
1172 params
1173 .iter()
1174 .map(|param| match ¶m.value {
1175 QueryValue::Null => {
1176 DataField::new(DataType::default(), param.name.clone(), Value::Null)
1177 }
1178 QueryValue::Bool(value) => {
1179 DataField::new(DataType::default(), param.name.clone(), Value::Bool(*value))
1180 }
1181 QueryValue::Int(value) => DataField::from_digit(param.name.clone(), *value),
1182 QueryValue::Float(value) => DataField::from_float(param.name.clone(), *value),
1183 QueryValue::Text(value) => DataField::from_chars(param.name.clone(), value.clone()),
1184 })
1185 .collect()
1186}
1187
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use std::sync::Arc;
    use wp_model_core::model::Value;

    /// Provider stub: every query, whatever its SQL or parameters, yields a
    /// single `value` column holding `self.value`.
    struct TestProvider {
        value: &'static str,
    }

    impl TestProvider {
        fn row(&self) -> RowData {
            vec![DataField::from_chars("value", self.value)]
        }
    }

    #[async_trait]
    impl ProviderExecutor for TestProvider {
        fn query(&self, _sql: &str) -> KnowledgeResult<Vec<RowData>> {
            Ok(vec![self.row()])
        }

        fn query_fields(&self, _sql: &str, _params: &[DataField]) -> KnowledgeResult<Vec<RowData>> {
            Ok(vec![self.row()])
        }

        fn query_row(&self, _sql: &str) -> KnowledgeResult<RowData> {
            Ok(self.row())
        }

        fn query_named_fields(
            &self,
            _sql: &str,
            _params: &[DataField],
        ) -> KnowledgeResult<RowData> {
            Ok(self.row())
        }
    }

    #[test]
    fn query_param_hash_is_stable() {
        let params = [
            QueryParam {
                name: ":id".to_string(),
                value: QueryValue::Int(7),
            },
            QueryParam {
                name: ":name".to_string(),
                value: QueryValue::Text("abc".to_string()),
            },
        ];
        let first = stable_params_hash(&params);
        let second = stable_params_hash(&params);
        assert_eq!(first, second);
    }

    #[test]
    fn fields_to_params_preserves_raw_chars_value() {
        let fields = [DataField::from_chars(
            ":name".to_string(),
            "令狐冲".to_string(),
        )];
        let params = fields_to_params(&fields);
        assert_eq!(params.len(), 1);
        let QueryValue::Text(text) = &params[0].value else {
            panic!("unexpected param value: {:?}", params[0].value);
        };
        assert_eq!(text, "令狐冲");

        let roundtrip = params_to_fields(&params);
        assert!(matches!(roundtrip[0].get_value(), Value::Chars(_)));
    }

    #[tokio::test(flavor = "current_thread")]
    async fn sqlite_async_bridge_keeps_captured_handle_after_reload() {
        let _serialized = runtime_test_guard().lock_async().await;
        let rt = runtime();

        rt.install_provider(
            ProviderKind::SqliteAuthority,
            DatasourceId("sqlite:old".to_string()),
            |_generation| Ok(Arc::new(TestProvider { value: "old" })),
        )
        .expect("install old provider");
        let captured = rt.current_handle().expect("current old handle");

        // Reload: the runtime now points at "new", but `captured` still owns "old".
        rt.install_provider(
            ProviderKind::SqliteAuthority,
            DatasourceId("sqlite:new".to_string()),
            |_generation| Ok(Arc::new(TestProvider { value: "new" })),
        )
        .expect("install new provider");

        let req = QueryRequest::first_row("SELECT value", Vec::new(), CachePolicy::Bypass);
        let response = task::spawn_blocking(move || runtime().execute_with_handle(&captured, &req))
            .await
            .expect("join sqlite bridge")
            .expect("execute old handle");
        let row = response.into_row();
        assert_eq!(row[0].to_string(), "chars(old)");
    }
}