1use std::path::Path;
2use std::sync::Arc;
3
4use crate::cache::CacheAble;
5use std::collections::HashSet;
6use std::sync::OnceLock;
7use wp_error::{KnowledgeReason, KnowledgeResult};
8use wp_log::info_ctrl;
9use wp_model_core::model::DataField;
10
11use crate::DBQuery;
12use crate::mem::RowData;
13use crate::mem::memdb::MemDB;
14use crate::mem::thread_clone::ThreadClonedMDB;
15use orion_error::{ErrorWith, ToStructError, UvsFrom};
17use rusqlite::ToSql;
18use rusqlite::{Connection, OpenFlags};
19
20pub trait QueryFacade: Send + Sync {
23 fn query(&self, sql: &str) -> KnowledgeResult<Vec<RowData>>;
24 fn query_row(&self, sql: &str) -> KnowledgeResult<RowData>;
25 fn query_named<'a>(
26 &self,
27 sql: &str,
28 params: &'a [(&'a str, &'a dyn ToSql)],
29 ) -> KnowledgeResult<RowData>;
30 fn query_cipher(&self, table: &str) -> KnowledgeResult<Vec<String>>;
31}
32
33impl QueryFacade for ThreadClonedMDB {
34 fn query(&self, sql: &str) -> KnowledgeResult<Vec<RowData>> {
35 DBQuery::query(self, sql)
36 }
37 fn query_row(&self, sql: &str) -> KnowledgeResult<RowData> {
38 DBQuery::query_row(self, sql)
39 }
40 fn query_named<'a>(
41 &self,
42 sql: &str,
43 params: &'a [(&'a str, &'a dyn ToSql)],
44 ) -> KnowledgeResult<RowData> {
45 DBQuery::query_row_params(self, sql, params)
46 }
47 fn query_cipher(&self, table: &str) -> KnowledgeResult<Vec<String>> {
48 DBQuery::query_cipher(self, table)
49 }
50}
51
52struct MemProvider(MemDB);
53impl QueryFacade for MemProvider {
54 fn query(&self, sql: &str) -> KnowledgeResult<Vec<RowData>> {
55 DBQuery::query(&self.0, sql)
56 }
57 fn query_row(&self, sql: &str) -> KnowledgeResult<RowData> {
58 DBQuery::query_row(&self.0, sql)
59 }
60 fn query_named<'a>(
61 &self,
62 sql: &str,
63 params: &'a [(&'a str, &'a dyn ToSql)],
64 ) -> KnowledgeResult<RowData> {
65 DBQuery::query_row_params(&self.0, sql, params)
66 }
67 fn query_cipher(&self, table: &str) -> KnowledgeResult<Vec<String>> {
68 DBQuery::query_cipher(&self.0, table)
69 }
70}
71
/// Process-wide query backend. It is written at most once (via `set_provider`), and every
/// free query function in this module reads it through `get_provider`.
static PROVIDER: OnceLock<Arc<dyn QueryFacade>> = OnceLock::new();
/// Optional allow-list of table names consulted by `query_cipher`. When it was never set,
/// every table is allowed. In this file it is only populated by
/// `init_thread_cloned_from_knowdb`.
static TABLE_WHITELIST: OnceLock<HashSet<String>> = OnceLock::new();
74
75pub fn init_thread_cloned_from_authority(authority_uri: &str) -> KnowledgeResult<()> {
77 let tc = ThreadClonedMDB::from_authority(authority_uri);
78 set_provider(Arc::new(tc))
79}
80
81pub fn init_mem_provider(memdb: MemDB) -> KnowledgeResult<()> {
83 let res = set_provider(Arc::new(MemProvider(memdb)));
84 if res.is_err() {
85 eprintln!("[kdb] provider already initialized");
86 } else {
87 eprintln!("[kdb] provider set to MemProvider");
88 }
89 res
90}
91
92fn set_provider(p: Arc<dyn QueryFacade>) -> KnowledgeResult<()> {
93 PROVIDER.set(p).map_err(|_| {
94 KnowledgeReason::from_logic()
95 .to_err()
96 .with_detail("knowledge provider already initialized")
97 })
98}
99
100fn get_provider() -> KnowledgeResult<&'static Arc<dyn QueryFacade>> {
101 PROVIDER.get().ok_or_else(|| {
102 KnowledgeReason::from_logic()
103 .to_err()
104 .with_detail("knowledge provider not initialized")
105 })
106}
107
108pub fn query(sql: &str) -> KnowledgeResult<Vec<RowData>> {
109 get_provider()?.query(sql)
110}
111
112pub fn query_row(sql: &str) -> KnowledgeResult<RowData> {
114 get_provider()?.query_row(sql)
115}
116
117pub fn query_named<'a>(
119 sql: &str,
120 params: &'a [(&'a str, &'a dyn ToSql)],
121) -> KnowledgeResult<RowData> {
122 get_provider()?.query_named(sql, params)
123}
124
125pub fn query_cipher(table: &str) -> KnowledgeResult<Vec<String>> {
127 if let Some(wl) = TABLE_WHITELIST.get()
128 && !wl.contains(table)
129 {
130 return Err(KnowledgeReason::from_logic()
131 .to_err()
132 .with_detail("table not allowed by knowdb whitelist")
133 .with(("table", table)));
134 }
135 get_provider()?.query_cipher(table)
136}
137
138pub fn cache_query<const N: usize>(
143 sql: &str,
144 c_params: &[DataField; N],
145 named_params: &[(&str, &dyn ToSql)],
146 cache: &mut impl CacheAble<DataField, RowData, N>,
147) -> RowData {
148 crate::cache_util::cache_query_impl(c_params, cache, || {
149 if named_params.is_empty() {
150 get_provider().and_then(|p| p.query_row(sql))
151 } else {
152 get_provider().and_then(|p| p.query_named(sql, named_params))
153 }
154 })
155}
156
157fn ensure_wal(authority_uri: &str) -> KnowledgeResult<()> {
158 if let Ok(conn) = Connection::open_with_flags(
160 authority_uri,
161 OpenFlags::SQLITE_OPEN_READ_WRITE
162 | OpenFlags::SQLITE_OPEN_CREATE
163 | OpenFlags::SQLITE_OPEN_URI,
164 ) {
165 let _ = conn.execute_batch(
166 "PRAGMA journal_mode=WAL;\nPRAGMA synchronous=NORMAL;\nPRAGMA temp_store=MEMORY;",
167 );
168 }
169 Ok(())
170}
171
172pub fn init_wal_pool_from_authority(authority_uri: &str, pool_size: u32) -> KnowledgeResult<()> {
174 ensure_wal(authority_uri)?;
175 let flags = OpenFlags::SQLITE_OPEN_READ_ONLY | OpenFlags::SQLITE_OPEN_URI;
176 let mem = MemDB::new_file(authority_uri, pool_size, flags)?;
177 init_mem_provider(mem)
178}
179
180pub fn init_thread_cloned_from_knowdb(
184 root: &Path,
185 knowdb_conf: &Path,
186 authority_uri: &str,
187 dict: &orion_variate::EnvDict,
188) -> KnowledgeResult<()> {
189 let tables =
190 crate::loader::build_authority_from_knowdb(root, knowdb_conf, authority_uri, dict)?;
191 let ro_uri = if let Some(rest) = authority_uri.strip_prefix("file:") {
193 let path_part = rest.split('?').next().unwrap_or(rest);
194 format!("file:{}?mode=ro&uri=true", path_part)
195 } else {
196 authority_uri.to_string()
197 };
198 let tc = ThreadClonedMDB::from_authority(&ro_uri);
199
200 #[cfg(test)]
203 {
204 tc.with_tls_conn(|_| Ok(()))?;
206 }
207
208 let _ = TABLE_WHITELIST.set(tables.into_iter().collect::<HashSet<_>>());
209 info_ctrl!("init authority knowdb success({}) ", knowdb_conf.display(),);
210 set_provider(Arc::new(tc))
211}