Skip to main content

wp_knowledge/
loader.rs

1use std::fs;
2use std::io::Read;
3use std::path::{Path, PathBuf};
4
5use orion_conf::EnvTomlLoad;
6use serde::Deserialize;
7use wp_log::info_ctrl;
8
9use crate::error::{KnowledgeResult, Reason};
10use crate::mem::memdb::MemDB;
11use orion_error::compat_traits::ErrorOweBase;
12use orion_error::conversion::ToStructError;
13use orion_error::runtime::ContextRecord;
14use orion_error::{OperationContext, UvsFrom};
15use orion_variate::EnvDict;
16use rusqlite::OpenFlags;
17
/// V2 KnowDB configuration: directory-based layout with external SQL files.
/// Only a single data file per table is supported: `<table_dir>/data.csv`,
/// or the file named by `tables[n].data_file`, relative to `<table_dir>`.
#[derive(Debug, Deserialize)]
pub struct KnowDbConf {
    /// Schema version of the config; `parse_knowdb_conf` rejects any value other than 2.
    pub version: u32,
    /// Directory that contains the per-table directories. Defaults to "." and is
    /// resolved against the config file's directory unless it is absolute.
    #[serde(default = "default_dot")]
    pub base_dir: String,
    /// Load options handed to every table load (transaction, batch size, error policy).
    #[serde(default)]
    pub default: OptLoadSpec,
    /// CSV parsing options shared by all tables.
    #[serde(default)]
    pub csv: CsvSpec,
    /// Cache settings. Not consumed in this file — presumably read by the query layer.
    #[serde(default)]
    pub cache: CacheSpec,
    /// Optional external provider description. Not consumed in this file.
    #[serde(default)]
    pub provider: Option<ProviderSpec>,
    /// Tables to load, processed in the order they are listed.
    #[serde(default)]
    pub tables: Vec<TableSpec>,
}
36
/// Cache settings parsed from the `[cache]` section.
/// Every field is optional in the TOML and falls back to the defaults noted below.
#[derive(Debug, Clone, Deserialize)]
pub struct CacheSpec {
    /// Whether the cache is enabled; defaults to `true`.
    #[serde(default = "default_true")]
    pub enabled: bool,
    /// Cache capacity; defaults to 1024 (unit is presumably entries — confirm with the consumer).
    #[serde(default = "default_result_cache_capacity")]
    pub capacity: usize,
    /// Time-to-live in milliseconds; defaults to 30_000.
    #[serde(default = "default_result_cache_ttl_ms")]
    pub ttl_ms: u64,
}
46
47impl Default for CacheSpec {
48    fn default() -> Self {
49        Self {
50            enabled: default_true(),
51            capacity: default_result_cache_capacity(),
52            ttl_ms: default_result_cache_ttl_ms(),
53        }
54    }
55}
56
/// Kind of provider named by `provider.kind`.
/// Variants are spelled in snake_case in the config
/// (`sqlite_authority`, `postgres`, `mysql`).
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ProviderKind {
    SqliteAuthority,
    Postgres,
    Mysql,
}
64
/// Provider description parsed from the `[provider]` section.
/// `kind` and `connection_uri` are required; the remaining settings are optional
/// and stay `None` when omitted (no defaults are applied in this file).
#[derive(Debug, Clone, Deserialize)]
pub struct ProviderSpec {
    /// Which provider to use.
    pub kind: ProviderKind,
    /// Connection URI, e.g. `postgres://user:pass@host/db`.
    pub connection_uri: String,
    /// Optional pool size (presumably the maximum number of connections — confirm with the consumer).
    #[serde(default)]
    pub pool_size: Option<u32>,
    /// Optional minimum number of connections.
    #[serde(default)]
    pub min_connections: Option<u32>,
    /// Optional acquire timeout in milliseconds.
    #[serde(default)]
    pub acquire_timeout_ms: Option<u64>,
    /// Optional idle timeout in milliseconds.
    #[serde(default)]
    pub idle_timeout_ms: Option<u64>,
    /// Optional maximum connection lifetime in milliseconds.
    #[serde(default)]
    pub max_lifetime_ms: Option<u64>,
}
80
/// Options controlling how CSV rows are inserted into a table.
#[derive(Debug, Clone, Deserialize)]
pub struct OptLoadSpec {
    /// Insert rows inside transactions when `true` (the default).
    #[serde(default = "default_true")]
    pub transaction: bool,
    /// Number of inserted rows per commit when `transaction` is enabled;
    /// defaults to 2000 and is clamped to at least 1 by the loader.
    #[serde(default = "default_batch")]
    pub batch_size: usize,
    /// What to do with a row that cannot be parsed or lacks a mapped column;
    /// defaults to `OnError::Fail`.
    #[serde(default = "default_on_error")]
    pub on_error: OnError,
}
90impl Default for OptLoadSpec {
91    fn default() -> Self {
92        Self {
93            transaction: true,
94            batch_size: default_batch(),
95            on_error: default_on_error(),
96        }
97    }
98}
99
/// Policy for rows that cannot be loaded. Spelled in lowercase in the config
/// (`fail`, `skip`).
#[derive(Debug, Clone, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum OnError {
    /// Abort the load with an error (the default).
    #[default]
    Fail,
    /// Count the row as bad and continue with the next one.
    Skip,
}
107
/// CSV parsing options parsed from the `[csv]` section.
#[derive(Debug, Clone, Deserialize)]
pub struct CsvSpec {
    /// Whether the first record is a header row; defaults to `true`.
    #[serde(default = "default_true")]
    pub has_header: bool,
    /// Field delimiter; defaults to ",". The reader only applies it when it is
    /// exactly one byte long.
    #[serde(default = "default_comma")]
    pub delimiter: String,
    /// Text encoding; defaults to "utf-8", the only value the reader accepts
    /// (compared case-insensitively).
    #[serde(default = "default_utf8")]
    pub encoding: String,
    /// Trim whitespace around headers and fields; defaults to `true`.
    #[serde(default = "default_true")]
    pub trim: bool,
}
119impl Default for CsvSpec {
120    fn default() -> Self {
121        CsvSpec {
122            has_header: true,
123            delimiter: ",".into(),
124            encoding: "utf-8".into(),
125            trim: true,
126        }
127    }
128}
129
/// One table to load, parsed from a `[[tables]]` entry.
#[derive(Debug, Clone, Deserialize)]
pub struct TableSpec {
    /// Table name; substituted for `{table}` in the SQL files and used as the
    /// directory name when `dir` is not set.
    pub name: String,
    /// Directory name under the base directory; defaults to `name`.
    #[serde(default)]
    pub dir: Option<String>,
    /// Data file path, relative to the table directory unless absolute;
    /// defaults to `data.csv`.
    #[serde(default)]
    pub data_file: Option<String>,
    /// Which CSV columns feed the insert statement.
    pub columns: ColumnsSpec,
    /// Optional bounds on the number of inserted rows.
    #[serde(default)]
    pub expected_rows: RowExpect,
    /// Disabled tables are skipped by the loader; defaults to `true`.
    #[serde(default = "default_true")]
    pub enabled: bool,
}
143
/// Column mapping from CSV fields to insert parameters.
/// `by_header` takes precedence when non-empty, otherwise `by_index` is used;
/// the loader reports a config error when both are empty.
#[derive(Debug, Clone, Deserialize)]
pub struct ColumnsSpec {
    /// Header names to select, in parameter order.
    #[serde(default)]
    pub by_header: Vec<String>,
    /// Zero-based field indices to select, in parameter order.
    #[serde(default)]
    pub by_index: Vec<usize>,
}
151
/// Optional bounds on the number of rows inserted for a table.
#[derive(Debug, Clone, Deserialize, Default)]
pub struct RowExpect {
    /// Minimum number of inserted rows; the load fails when fewer were inserted.
    pub min: Option<usize>,
    /// Maximum number of inserted rows; exceeding it only logs a warning.
    pub max: Option<usize>,
}
157
/// Serde default helper for boolean options that are on unless disabled.
const fn default_true() -> bool {
    true
}
/// Serde default for `OptLoadSpec::batch_size`.
const fn default_batch() -> usize {
    2_000
}
/// Serde default for `CsvSpec::delimiter`.
fn default_comma() -> String {
    String::from(",")
}
/// Serde default for `CsvSpec::encoding`.
fn default_utf8() -> String {
    String::from("utf-8")
}
/// Serde default for `OptLoadSpec::on_error`: fail on the first bad row.
fn default_on_error() -> OnError {
    OnError::Fail
}
/// Serde default for `KnowDbConf::base_dir`: the config file's own directory.
fn default_dot() -> String {
    String::from(".")
}
/// Serde default for `CacheSpec::capacity`.
const fn default_result_cache_capacity() -> usize {
    1_024
}
/// Serde default for `CacheSpec::ttl_ms`: thirty seconds, in milliseconds.
const fn default_result_cache_ttl_ms() -> u64 {
    30 * 1_000
}
182
183/// 读取文本文件,返回字符串
184fn read_to_string(path: &Path) -> KnowledgeResult<String> {
185    let mut f = fs::File::open(path).owe(Reason::from_res())?;
186    let mut buf = String::new();
187    f.read_to_string(&mut buf).owe(Reason::from_res())?;
188    Ok(buf)
189}
190
/// Substitutes every `{table}` placeholder in `sql` with `table`.
///
/// The name is inserted verbatim (no quoting or escaping), so it must come
/// from trusted configuration.
fn replace_table(sql: &str, table: &str) -> String {
    const PLACEHOLDER: &str = "{table}";
    let pieces: Vec<&str> = sql.split(PLACEHOLDER).collect();
    pieces.join(table)
}
194
/// Resolves `rel` against `base`: an absolute `rel` is returned unchanged,
/// anything else is appended to `base`.
fn join_rel(base: &Path, rel: &str) -> PathBuf {
    let candidate = Path::new(rel);
    if candidate.is_absolute() {
        return candidate.to_path_buf();
    }
    base.join(candidate)
}
203
204pub fn build_authority_from_knowdb(
205    root: &Path,
206    conf_path: &Path,
207    authority_uri: &str,
208    dict: &EnvDict,
209) -> KnowledgeResult<Vec<String>> {
210    let mut opx = OperationContext::doing("build authority from knowdb").with_auto_log();
211    // 1) 解析配置与 base_dir
212    let (conf, conf_abs, base_dir) = parse_knowdb_conf(root, conf_path, dict)?;
213    opx.record("conf", &conf_abs);
214    opx.record("base_dir", &base_dir);
215    // 2) 打开权威库
216    let db = open_authority(authority_uri)?;
217    // 3) 逐表加载(按配置顺序);不再处理显式依赖
218    let mut loaded_names = Vec::new();
219    for t in &conf.tables {
220        if !t.enabled {
221            continue;
222        }
223        load_one_table(&db, &base_dir, t, &conf.csv, &conf.default)?;
224        info_ctrl!("load table {} suc!", base_dir.display(),);
225        loaded_names.push(t.name.clone());
226    }
227    opx.mark_suc();
228    Ok(loaded_names)
229}
230
231pub fn parse_knowdb_conf(
232    root: &Path,
233    conf_path: &Path,
234    dict: &EnvDict,
235) -> KnowledgeResult<(KnowDbConf, PathBuf, PathBuf)> {
236    let conf_abs = if conf_path.is_absolute() {
237        conf_path.to_path_buf()
238    } else {
239        root.join(conf_path)
240    };
241    let conf_txt = read_to_string(&conf_abs)?;
242    let conf: KnowDbConf = <KnowDbConf as EnvTomlLoad<KnowDbConf>>::env_parse_toml(&conf_txt, dict)
243        .owe(Reason::from_conf())?;
244    if conf.version != 2 {
245        return Err(Reason::from_conf()
246            .to_err()
247            .with_detail("unsupported knowdb.version"));
248    }
249    let conf_dir = conf_abs.parent().unwrap_or_else(|| Path::new("."));
250    let base_dir = join_rel(conf_dir, &conf.base_dir);
251    Ok((conf, conf_abs, base_dir))
252}
253
254fn open_authority(authority_uri: &str) -> KnowledgeResult<MemDB> {
255    ensure_parent_dir_for_file_uri(authority_uri);
256    let flags = OpenFlags::SQLITE_OPEN_READ_WRITE
257        | OpenFlags::SQLITE_OPEN_CREATE
258        | OpenFlags::SQLITE_OPEN_URI;
259    let db = MemDB::new_file(authority_uri, 1, flags)?;
260    // 预注册内置 UDF 至权威库连接(注意:连接池可能返回不同连接,导入时也会再次注册)
261    let _ = db.with_conn(|conn| {
262        let _ = crate::sqlite_ext::register_builtin(conn);
263        Ok::<(), anyhow::Error>(())
264    });
265    Ok(db)
266}
267
/// Best-effort creation of the parent directory of a `file:` URI.
///
/// Anything after the first `?` is treated as URI parameters and ignored.
/// URIs without the `file:` prefix are left alone, and directory-creation
/// errors are ignored (opening the database reports the real failure).
fn ensure_parent_dir_for_file_uri(uri: &str) {
    let Some(after_scheme) = uri.strip_prefix("file:") else {
        return;
    };
    let fs_part = after_scheme
        .split_once('?')
        .map_or(after_scheme, |(before_query, _)| before_query);
    if let Some(dir) = Path::new(fs_part).parent() {
        let _ = fs::create_dir_all(dir);
    }
}
279
280fn load_one_table(
281    db: &MemDB,
282    base_dir: &Path,
283    t: &TableSpec,
284    csvd: &CsvSpec,
285    load: &OptLoadSpec,
286) -> KnowledgeResult<()> {
287    // 目录与必须文件
288    let mut opx = OperationContext::doing("load table to kdb")
289        .with_auto_log()
290        .with_mod_path("ctrl");
291    let dir_name: &str = t.dir.as_deref().unwrap_or(&t.name);
292    let table_dir = base_dir.join(dir_name);
293    opx.record("table_dir", &table_dir);
294    let create_sql = replace_table(&read_to_string(&table_dir.join("create.sql"))?, &t.name);
295    let insert_sql = replace_table(&read_to_string(&table_dir.join("insert.sql"))?, &t.name);
296    let clean_path = table_dir.join("clean.sql");
297    let clean_sql = if clean_path.exists() {
298        replace_table(&read_to_string(&clean_path)?, &t.name)
299    } else {
300        format!("DELETE FROM {}", &t.name)
301    };
302
303    // 建表与清理
304    db.with_conn(|conn| {
305        // 注册内置 UDF(导入连接)
306        let _ = crate::sqlite_ext::register_builtin(conn);
307        conn.execute_batch(&create_sql)?;
308        conn.execute_batch(&clean_sql)?;
309        Ok::<(), anyhow::Error>(())
310    })
311    .owe(Reason::from_res())?;
312
313    // 数据源
314    let data_path = match &t.data_file {
315        Some(rel) => join_rel(&table_dir, rel),
316        None => table_dir.join("data.csv"),
317    };
318    if !data_path.exists() {
319        return Err(Reason::from_conf()
320            .to_err()
321            .with_detail("data.csv not found"));
322    }
323    opx.record("data_path", &data_path);
324
325    // CSV 解析器
326    let mut rdr = build_csv_reader(csvd, &data_path)?;
327
328    // 列映射
329    let col_indices: Vec<usize> = if !t.columns.by_header.is_empty() {
330        let headers = rdr.headers().owe(Reason::from_res())?;
331        select_indices_by_header(headers, &t.columns.by_header)?
332    } else if !t.columns.by_index.is_empty() {
333        t.columns.by_index.clone()
334    } else {
335        return Err(Reason::from_conf()
336            .to_err()
337            .with_detail("columns mapping required"));
338    };
339
340    // 导入(分批事务)
341    let mut inserted: usize = 0;
342    let mut bad: usize = 0;
343    let mut batch_left = load.batch_size.max(1);
344    db.with_conn(|conn| {
345        // 注册内置 UDF(用于 INSERT 绑定表达式)
346        let _ = crate::sqlite_ext::register_builtin(conn);
347        let mut tx = if load.transaction {
348            Some(conn.unchecked_transaction()?)
349        } else {
350            None
351        };
352        let mut stmt = conn.prepare(&insert_sql)?;
353        for rec in rdr.into_records() {
354            match rec {
355                Ok(record) => {
356                    let refs = extract_row_refs(&record, &col_indices, &mut bad, load)?;
357                    if let Some(refs) = refs {
358                        stmt.execute(rusqlite::params_from_iter(refs))?;
359                        inserted += 1;
360                        if load.transaction {
361                            batch_left -= 1;
362                            if batch_left == 0 {
363                                tx.take().unwrap().commit()?;
364                                tx = Some(conn.unchecked_transaction()?);
365                                batch_left = load.batch_size.max(1);
366                            }
367                        }
368                    }
369                }
370                Err(_e) => {
371                    if matches!(load.on_error, OnError::Skip) {
372                        bad += 1;
373                        continue;
374                    } else {
375                        anyhow::bail!("csv record parse error");
376                    }
377                }
378            }
379        }
380        if let Some(tx) = tx {
381            tx.commit()?;
382        }
383        Ok::<(), anyhow::Error>(())
384    })
385    .owe(Reason::from_res())?;
386
387    // 行数校验
388    if let Some(min) = t.expected_rows.min
389        && inserted < min
390    {
391        return Err(Reason::from_conf().to_err().with_detail("table data less"));
392    }
393    if let Some(max) = t.expected_rows.max
394        && inserted > max
395    {
396        wp_log::warn_kdb!(
397            "table {} loaded rows {} exceed max {}",
398            &t.name,
399            inserted,
400            max
401        );
402    }
403    if bad > 0 {
404        wp_log::warn_kdb!("table {} skipped {} bad rows (on_error=skip)", &t.name, bad);
405    }
406    opx.mark_suc();
407    Ok(())
408}
409
410fn build_csv_reader(
411    csvd: &CsvSpec,
412    data_path: &Path,
413) -> KnowledgeResult<csv::Reader<std::fs::File>> {
414    if csvd.encoding.to_lowercase() != "utf-8" {
415        return Err(Reason::from_conf()
416            .to_err()
417            .with_detail("only utf-8 csv is supported"));
418    }
419    let mut rdr_b = csv::ReaderBuilder::new();
420    rdr_b.has_headers(csvd.has_header);
421    if csvd.delimiter.len() == 1 {
422        rdr_b.delimiter(csvd.delimiter.as_bytes()[0]);
423    }
424    if csvd.trim {
425        rdr_b.trim(csv::Trim::All);
426    }
427    rdr_b.from_path(data_path).owe(Reason::from_res())
428}
429
430fn select_indices_by_header(
431    headers: &csv::StringRecord,
432    wanted: &[String],
433) -> KnowledgeResult<Vec<usize>> {
434    let mut out = Vec::with_capacity(wanted.len());
435    for name in wanted {
436        let pos = headers
437            .iter()
438            .position(|h| h == name)
439            .ok_or_else(|| Reason::from_conf().to_err().with_detail("header not found"))?;
440        out.push(pos);
441    }
442    Ok(out)
443}
444
445fn extract_row_refs<'a>(
446    record: &'a csv::StringRecord,
447    col_indices: &[usize],
448    bad: &mut usize,
449    load: &OptLoadSpec,
450) -> anyhow::Result<Option<Vec<&'a str>>> {
451    let mut vs: Vec<&str> = Vec::with_capacity(col_indices.len());
452    for &idx in col_indices {
453        if idx >= record.len() {
454            if matches!(load.on_error, OnError::Skip) {
455                *bad += 1;
456                return Ok(None);
457            } else {
458                anyhow::bail!("missing column at index {}", idx);
459            }
460        }
461        vs.push(record.get(idx).unwrap_or(""));
462    }
463    Ok(Some(vs))
464}
465
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `toml` with an empty environment dictionary, panicking with
    /// `expectation` when parsing fails.
    fn parse_conf(toml: &str, expectation: &str) -> KnowDbConf {
        let dict = EnvDict::default();
        <KnowDbConf as EnvTomlLoad<KnowDbConf>>::env_parse_toml(toml, &dict).expect(expectation)
    }

    /// A postgres provider section is parsed, and no tables are required.
    #[test]
    fn parse_external_provider_spec() {
        let conf = parse_conf(
            r#"
version = 2

[provider]
kind = "postgres"
connection_uri = "postgres://demo:demo@127.0.0.1/demo"
min_connections = 2
acquire_timeout_ms = 1500
idle_timeout_ms = 30000
max_lifetime_ms = 60000
"#,
            "parse knowdb with provider",
        );

        assert!(conf.tables.is_empty());
        let provider = conf.provider.expect("provider");
        assert!(matches!(provider.kind, ProviderKind::Postgres));
        assert_eq!(
            provider.connection_uri,
            "postgres://demo:demo@127.0.0.1/demo"
        );
        assert_eq!(provider.min_connections, Some(2));
        assert_eq!(provider.acquire_timeout_ms, Some(1500));
        assert_eq!(provider.idle_timeout_ms, Some(30000));
        assert_eq!(provider.max_lifetime_ms, Some(60000));
    }

    /// A mysql provider section is parsed, including the optional pool size.
    #[test]
    fn parse_mysql_provider_spec() {
        let conf = parse_conf(
            r#"
version = 2

[provider]
kind = "mysql"
connection_uri = "mysql://demo:demo@127.0.0.1:3306/demo"
pool_size = 12
min_connections = 3
acquire_timeout_ms = 2500
idle_timeout_ms = 45000
max_lifetime_ms = 120000
"#,
            "parse knowdb with mysql provider",
        );

        let provider = conf.provider.expect("provider");
        assert!(matches!(provider.kind, ProviderKind::Mysql));
        assert_eq!(
            provider.connection_uri,
            "mysql://demo:demo@127.0.0.1:3306/demo"
        );
        assert_eq!(provider.pool_size, Some(12));
        assert_eq!(provider.min_connections, Some(3));
        assert_eq!(provider.acquire_timeout_ms, Some(2500));
        assert_eq!(provider.idle_timeout_ms, Some(45000));
        assert_eq!(provider.max_lifetime_ms, Some(120000));
    }

    /// Without a `[cache]` section the cache defaults apply.
    #[test]
    fn parse_cache_spec_with_defaults() {
        let conf = parse_conf(
            r#"
version = 2
"#,
            "parse knowdb with default cache spec",
        );

        assert!(conf.cache.enabled);
        assert_eq!(conf.cache.capacity, 1024);
        assert_eq!(conf.cache.ttl_ms, 30_000);
    }

    /// Explicit `[cache]` values override the defaults.
    #[test]
    fn parse_cache_spec_from_toml() {
        let conf = parse_conf(
            r#"
version = 2

[cache]
enabled = false
capacity = 256
ttl_ms = 1500
"#,
            "parse knowdb with cache spec",
        );

        assert!(!conf.cache.enabled);
        assert_eq!(conf.cache.capacity, 256);
        assert_eq!(conf.cache.ttl_ms, 1500);
    }
}