//! `wp_knowledge/loader.rs`: KnowDB V2 configuration types and the CSV loader that fills the
//! authority database.

1use std::fs;
2use std::io::Read;
3use std::path::{Path, PathBuf};
4
5use orion_conf::EnvTomlLoad;
6use serde::Deserialize;
7use wp_log::info_ctrl;
8
9use crate::error::{KnowReason, KnowledgeResult};
10use crate::mem::memdb::MemDB;
11use orion_error::OperationContext;
12use orion_error::conversion::{SourceErr, SourceRawErr, ToStructError};
13use orion_variate::EnvDict;
14use rusqlite::OpenFlags;
15
16/// V2 KnowDB 配置:目录式 + 外置 SQL。仅支持单一数据文件:`<table_dir>/data.csv`,
17/// 或通过 `tables[n].data_file` 相对 `<table_dir>` 指定。
18#[derive(Debug, Deserialize)]
19pub struct KnowDbConf {
20    pub version: u32,
21    #[serde(default = "default_dot")]
22    pub base_dir: String,
23    #[serde(default)]
24    pub default: OptLoadSpec,
25    #[serde(default)]
26    pub csv: CsvSpec,
27    #[serde(default)]
28    pub cache: CacheSpec,
29    #[serde(default)]
30    pub provider: Option<ProviderSpec>,
31    #[serde(default)]
32    pub tables: Vec<TableSpec>,
33}
34
35#[derive(Debug, Clone, Deserialize)]
36pub struct CacheSpec {
37    #[serde(default = "default_true")]
38    pub enabled: bool,
39    #[serde(default = "default_result_cache_capacity")]
40    pub capacity: usize,
41    #[serde(default = "default_result_cache_ttl_ms")]
42    pub ttl_ms: u64,
43}
44
45impl Default for CacheSpec {
46    fn default() -> Self {
47        Self {
48            enabled: default_true(),
49            capacity: default_result_cache_capacity(),
50            ttl_ms: default_result_cache_ttl_ms(),
51        }
52    }
53}
54
55#[derive(Debug, Clone, Deserialize)]
56#[serde(rename_all = "snake_case")]
57pub enum ProviderKind {
58    SqliteAuthority,
59    Postgres,
60    Mysql,
61}
62
63#[derive(Debug, Clone, Deserialize)]
64pub struct ProviderSpec {
65    pub kind: ProviderKind,
66    pub connection_uri: String,
67    #[serde(default)]
68    pub pool_size: Option<u32>,
69    #[serde(default)]
70    pub min_connections: Option<u32>,
71    #[serde(default)]
72    pub acquire_timeout_ms: Option<u64>,
73    #[serde(default)]
74    pub idle_timeout_ms: Option<u64>,
75    #[serde(default)]
76    pub max_lifetime_ms: Option<u64>,
77}
78
79#[derive(Debug, Clone, Deserialize)]
80pub struct OptLoadSpec {
81    #[serde(default = "default_true")]
82    pub transaction: bool,
83    #[serde(default = "default_batch")]
84    pub batch_size: usize,
85    #[serde(default = "default_on_error")]
86    pub on_error: OnError,
87}
88impl Default for OptLoadSpec {
89    fn default() -> Self {
90        Self {
91            transaction: true,
92            batch_size: default_batch(),
93            on_error: default_on_error(),
94        }
95    }
96}
97
98#[derive(Debug, Clone, Deserialize, Default)]
99#[serde(rename_all = "lowercase")]
100pub enum OnError {
101    #[default]
102    Fail,
103    Skip,
104}
105
106#[derive(Debug, Clone, Deserialize)]
107pub struct CsvSpec {
108    #[serde(default = "default_true")]
109    pub has_header: bool,
110    #[serde(default = "default_comma")]
111    pub delimiter: String,
112    #[serde(default = "default_utf8")]
113    pub encoding: String,
114    #[serde(default = "default_true")]
115    pub trim: bool,
116}
117impl Default for CsvSpec {
118    fn default() -> Self {
119        CsvSpec {
120            has_header: true,
121            delimiter: ",".into(),
122            encoding: "utf-8".into(),
123            trim: true,
124        }
125    }
126}
127
128#[derive(Debug, Clone, Deserialize)]
129pub struct TableSpec {
130    pub name: String,
131    #[serde(default)]
132    pub dir: Option<String>,
133    #[serde(default)]
134    pub data_file: Option<String>,
135    pub columns: ColumnsSpec,
136    #[serde(default)]
137    pub expected_rows: RowExpect,
138    #[serde(default = "default_true")]
139    pub enabled: bool,
140}
141
142#[derive(Debug, Clone, Deserialize)]
143pub struct ColumnsSpec {
144    #[serde(default)]
145    pub by_header: Vec<String>,
146    #[serde(default)]
147    pub by_index: Vec<usize>,
148}
149
150#[derive(Debug, Clone, Deserialize, Default)]
151pub struct RowExpect {
152    pub min: Option<usize>,
153    pub max: Option<usize>,
154}
155
/// Serde default for boolean switches that are on unless configured otherwise.
const fn default_true() -> bool {
    true
}

/// Serde default for `OptLoadSpec::batch_size`.
const fn default_batch() -> usize {
    2_000
}

/// Serde default for `CsvSpec::delimiter`.
fn default_comma() -> String {
    String::from(",")
}

/// Serde default for `CsvSpec::encoding`.
fn default_utf8() -> String {
    String::from("utf-8")
}
168fn default_on_error() -> OnError {
169    OnError::Fail
170}
/// Serde default for `KnowDbConf::base_dir`: the configuration file's own directory.
fn default_dot() -> String {
    String::from(".")
}

/// Serde default for `CacheSpec::capacity`.
const fn default_result_cache_capacity() -> usize {
    1_024
}

/// Serde default for `CacheSpec::ttl_ms` (30 seconds).
const fn default_result_cache_ttl_ms() -> u64 {
    30 * 1_000
}
180
181/// 读取文本文件,返回字符串
182fn read_to_string(path: &Path) -> KnowledgeResult<String> {
183    let mut f = fs::File::open(path).source_raw_err(KnowReason::from_res(), "source error")?;
184    let mut buf = String::new();
185    f.read_to_string(&mut buf)
186        .source_raw_err(KnowReason::from_res(), "source error")?;
187    Ok(buf)
188}
189
/// Returns `sql` with every `{table}` placeholder replaced by `table`.
fn replace_table(sql: &str, table: &str) -> String {
    const PLACEHOLDER: &str = "{table}";
    sql.replace(PLACEHOLDER, table)
}
193
/// Resolves `rel` against `base`: an absolute `rel` is returned unchanged, anything else is
/// appended to `base`.
fn join_rel(base: &Path, rel: &str) -> PathBuf {
    // `Path::join` already discards `base` when the joined path is absolute, so no explicit
    // `is_absolute` branch is needed.
    base.join(rel)
}
202
203pub fn build_authority_from_knowdb(
204    root: &Path,
205    conf_path: &Path,
206    authority_uri: &str,
207    dict: &EnvDict,
208) -> KnowledgeResult<Vec<String>> {
209    let mut opx = OperationContext::doing("build authority from knowdb").with_auto_log();
210    // 1) 解析配置与 base_dir
211    let (conf, conf_abs, base_dir) = parse_knowdb_conf(root, conf_path, dict)?;
212    opx.record("conf", conf_abs.display());
213    opx.record("base_dir", base_dir.display());
214    // 2) 打开权威库
215    let db = open_authority(authority_uri)?;
216    // 3) 逐表加载(按配置顺序);不再处理显式依赖
217    let mut loaded_names = Vec::new();
218    for t in &conf.tables {
219        if !t.enabled {
220            continue;
221        }
222        load_one_table(&db, &base_dir, t, &conf.csv, &conf.default)?;
223        info_ctrl!("load table {} suc!", base_dir.display(),);
224        loaded_names.push(t.name.clone());
225    }
226    opx.mark_suc();
227    Ok(loaded_names)
228}
229
230pub fn parse_knowdb_conf(
231    root: &Path,
232    conf_path: &Path,
233    dict: &EnvDict,
234) -> KnowledgeResult<(KnowDbConf, PathBuf, PathBuf)> {
235    let conf_abs = if conf_path.is_absolute() {
236        conf_path.to_path_buf()
237    } else {
238        root.join(conf_path)
239    };
240    let conf_txt = read_to_string(&conf_abs)?;
241    let conf: KnowDbConf = <KnowDbConf as EnvTomlLoad<KnowDbConf>>::env_parse_toml(&conf_txt, dict)
242        .source_err(KnowReason::from_conf(), "parse knowdb config")?;
243    if conf.version != 2 {
244        return Err(KnowReason::from_conf()
245            .to_err()
246            .with_detail("unsupported knowdb.version"));
247    }
248    let conf_dir = conf_abs.parent().unwrap_or_else(|| Path::new("."));
249    let base_dir = join_rel(conf_dir, &conf.base_dir);
250    Ok((conf, conf_abs, base_dir))
251}
252
253fn open_authority(authority_uri: &str) -> KnowledgeResult<MemDB> {
254    ensure_parent_dir_for_file_uri(authority_uri);
255    let flags = OpenFlags::SQLITE_OPEN_READ_WRITE
256        | OpenFlags::SQLITE_OPEN_CREATE
257        | OpenFlags::SQLITE_OPEN_URI;
258    let db = MemDB::new_file(authority_uri, 1, flags)?;
259    // 预注册内置 UDF 至权威库连接(注意:连接池可能返回不同连接,导入时也会再次注册)
260    let _ = db.with_conn(|conn| {
261        let _ = crate::sqlite_ext::register_builtin(conn);
262        Ok::<(), anyhow::Error>(())
263    });
264    Ok(db)
265}
266
/// Best-effort creation of the parent directory named by a `file:` URI.
///
/// Only URIs starting with `file:` are handled. Everything from the first `?` onward is dropped
/// and the remainder is treated as a filesystem path. Directory-creation errors are ignored.
fn ensure_parent_dir_for_file_uri(uri: &str) {
    let Some(rest) = uri.strip_prefix("file:") else {
        return;
    };
    let path_part = rest.split_once('?').map_or(rest, |(path, _query)| path);
    if let Some(parent) = Path::new(path_part).parent() {
        // Ignored on purpose: opening the database reports the real error if this failed.
        let _ = fs::create_dir_all(parent);
    }
}
278
279fn load_one_table(
280    db: &MemDB,
281    base_dir: &Path,
282    t: &TableSpec,
283    csvd: &CsvSpec,
284    load: &OptLoadSpec,
285) -> KnowledgeResult<()> {
286    // 目录与必须文件
287    let mut opx = OperationContext::doing("load table to kdb")
288        .with_auto_log()
289        .with_mod_path("ctrl");
290    let dir_name: &str = t.dir.as_deref().unwrap_or(&t.name);
291    let table_dir = base_dir.join(dir_name);
292    opx.record("table_dir", table_dir.display());
293    let create_sql = replace_table(&read_to_string(&table_dir.join("create.sql"))?, &t.name);
294    let insert_sql = replace_table(&read_to_string(&table_dir.join("insert.sql"))?, &t.name);
295    let clean_path = table_dir.join("clean.sql");
296    let clean_sql = if clean_path.exists() {
297        replace_table(&read_to_string(&clean_path)?, &t.name)
298    } else {
299        format!("DELETE FROM {}", &t.name)
300    };
301
302    // 建表与清理
303    db.with_conn(|conn| {
304        // 注册内置 UDF(导入连接)
305        let _ = crate::sqlite_ext::register_builtin(conn);
306        conn.execute_batch(&create_sql)?;
307        conn.execute_batch(&clean_sql)?;
308        Ok::<(), anyhow::Error>(())
309    })
310    .source_err(KnowReason::from_res(), "prepare authority table")?;
311
312    // 数据源
313    let data_path = match &t.data_file {
314        Some(rel) => join_rel(&table_dir, rel),
315        None => table_dir.join("data.csv"),
316    };
317    if !data_path.exists() {
318        return Err(KnowReason::from_conf()
319            .to_err()
320            .with_detail("data.csv not found"));
321    }
322    opx.record("data_path", data_path.display());
323
324    // CSV 解析器
325    let mut rdr = build_csv_reader(csvd, &data_path)?;
326
327    // 列映射
328    let col_indices: Vec<usize> = if !t.columns.by_header.is_empty() {
329        let headers = rdr
330            .headers()
331            .source_raw_err(KnowReason::from_res(), "source error")?;
332        select_indices_by_header(headers, &t.columns.by_header)?
333    } else if !t.columns.by_index.is_empty() {
334        t.columns.by_index.clone()
335    } else {
336        return Err(KnowReason::from_conf()
337            .to_err()
338            .with_detail("columns mapping required"));
339    };
340
341    // 导入(分批事务)
342    let mut inserted: usize = 0;
343    let mut bad: usize = 0;
344    let mut batch_left = load.batch_size.max(1);
345    db.with_conn(|conn| {
346        // 注册内置 UDF(用于 INSERT 绑定表达式)
347        let _ = crate::sqlite_ext::register_builtin(conn);
348        let mut tx = if load.transaction {
349            Some(conn.unchecked_transaction()?)
350        } else {
351            None
352        };
353        let mut stmt = conn.prepare(&insert_sql)?;
354        for rec in rdr.into_records() {
355            match rec {
356                Ok(record) => {
357                    let refs = extract_row_refs(&record, &col_indices, &mut bad, load)?;
358                    if let Some(refs) = refs {
359                        stmt.execute(rusqlite::params_from_iter(refs))?;
360                        inserted += 1;
361                        if load.transaction {
362                            batch_left -= 1;
363                            if batch_left == 0 {
364                                tx.take().unwrap().commit()?;
365                                tx = Some(conn.unchecked_transaction()?);
366                                batch_left = load.batch_size.max(1);
367                            }
368                        }
369                    }
370                }
371                Err(_e) => {
372                    if matches!(load.on_error, OnError::Skip) {
373                        bad += 1;
374                        continue;
375                    } else {
376                        anyhow::bail!("csv record parse error");
377                    }
378                }
379            }
380        }
381        if let Some(tx) = tx {
382            tx.commit()?;
383        }
384        Ok::<(), anyhow::Error>(())
385    })
386    .source_err(KnowReason::from_res(), "load authority table data")?;
387
388    // 行数校验
389    if let Some(min) = t.expected_rows.min
390        && inserted < min
391    {
392        return Err(KnowReason::from_conf()
393            .to_err()
394            .with_detail("table data less"));
395    }
396    if let Some(max) = t.expected_rows.max
397        && inserted > max
398    {
399        wp_log::warn_kdb!(
400            "table {} loaded rows {} exceed max {}",
401            &t.name,
402            inserted,
403            max
404        );
405    }
406    if bad > 0 {
407        wp_log::warn_kdb!("table {} skipped {} bad rows (on_error=skip)", &t.name, bad);
408    }
409    opx.mark_suc();
410    Ok(())
411}
412
413fn build_csv_reader(
414    csvd: &CsvSpec,
415    data_path: &Path,
416) -> KnowledgeResult<csv::Reader<std::fs::File>> {
417    if csvd.encoding.to_lowercase() != "utf-8" {
418        return Err(KnowReason::from_conf()
419            .to_err()
420            .with_detail("only utf-8 csv is supported"));
421    }
422    let mut rdr_b = csv::ReaderBuilder::new();
423    rdr_b.has_headers(csvd.has_header);
424    if csvd.delimiter.len() == 1 {
425        rdr_b.delimiter(csvd.delimiter.as_bytes()[0]);
426    }
427    if csvd.trim {
428        rdr_b.trim(csv::Trim::All);
429    }
430    rdr_b
431        .from_path(data_path)
432        .source_raw_err(KnowReason::from_res(), "source error")
433}
434
435fn select_indices_by_header(
436    headers: &csv::StringRecord,
437    wanted: &[String],
438) -> KnowledgeResult<Vec<usize>> {
439    let mut out = Vec::with_capacity(wanted.len());
440    for name in wanted {
441        let pos = headers.iter().position(|h| h == name).ok_or_else(|| {
442            KnowReason::from_conf()
443                .to_err()
444                .with_detail("header not found")
445        })?;
446        out.push(pos);
447    }
448    Ok(out)
449}
450
451fn extract_row_refs<'a>(
452    record: &'a csv::StringRecord,
453    col_indices: &[usize],
454    bad: &mut usize,
455    load: &OptLoadSpec,
456) -> anyhow::Result<Option<Vec<&'a str>>> {
457    let mut vs: Vec<&str> = Vec::with_capacity(col_indices.len());
458    for &idx in col_indices {
459        if idx >= record.len() {
460            if matches!(load.on_error, OnError::Skip) {
461                *bad += 1;
462                return Ok(None);
463            } else {
464                anyhow::bail!("missing column at index {}", idx);
465            }
466        }
467        vs.push(record.get(idx).unwrap_or(""));
468    }
469    Ok(Some(vs))
470}
471
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `toml` as a KnowDB configuration with an empty environment dictionary, panicking
    /// with `what` when parsing fails.
    fn parse_conf(toml: &str, what: &str) -> KnowDbConf {
        <KnowDbConf as EnvTomlLoad<KnowDbConf>>::env_parse_toml(toml, &EnvDict::default())
            .expect(what)
    }

    /// A postgres `[provider]` section round-trips; `pool_size` is left out on purpose.
    #[test]
    fn parse_external_provider_spec() {
        let conf = parse_conf(
            r#"
version = 2

[provider]
kind = "postgres"
connection_uri = "postgres://demo:demo@127.0.0.1/demo"
min_connections = 2
acquire_timeout_ms = 1500
idle_timeout_ms = 30000
max_lifetime_ms = 60000
"#,
            "parse knowdb with provider",
        );

        assert!(conf.tables.is_empty());
        let provider = conf.provider.expect("provider");
        assert!(matches!(provider.kind, ProviderKind::Postgres));
        assert_eq!(
            provider.connection_uri,
            "postgres://demo:demo@127.0.0.1/demo"
        );
        assert_eq!(provider.min_connections, Some(2));
        assert_eq!(provider.acquire_timeout_ms, Some(1500));
        assert_eq!(provider.idle_timeout_ms, Some(30000));
        assert_eq!(provider.max_lifetime_ms, Some(60000));
    }

    /// A mysql `[provider]` section with every tuning field set.
    #[test]
    fn parse_mysql_provider_spec() {
        let conf = parse_conf(
            r#"
version = 2

[provider]
kind = "mysql"
connection_uri = "mysql://demo:demo@127.0.0.1:3306/demo"
pool_size = 12
min_connections = 3
acquire_timeout_ms = 2500
idle_timeout_ms = 45000
max_lifetime_ms = 120000
"#,
            "parse knowdb with mysql provider",
        );

        let provider = conf.provider.expect("provider");
        assert!(matches!(provider.kind, ProviderKind::Mysql));
        assert_eq!(
            provider.connection_uri,
            "mysql://demo:demo@127.0.0.1:3306/demo"
        );
        assert_eq!(provider.pool_size, Some(12));
        assert_eq!(provider.min_connections, Some(3));
        assert_eq!(provider.acquire_timeout_ms, Some(2500));
        assert_eq!(provider.idle_timeout_ms, Some(45000));
        assert_eq!(provider.max_lifetime_ms, Some(120000));
    }

    /// Without a `[cache]` section the cache falls back to its defaults.
    #[test]
    fn parse_cache_spec_with_defaults() {
        let conf = parse_conf(
            r#"
version = 2
"#,
            "parse knowdb with default cache spec",
        );

        assert!(conf.cache.enabled);
        assert_eq!(conf.cache.capacity, 1024);
        assert_eq!(conf.cache.ttl_ms, 30_000);
    }

    /// Explicit `[cache]` values override every default.
    #[test]
    fn parse_cache_spec_from_toml() {
        let conf = parse_conf(
            r#"
version = 2

[cache]
enabled = false
capacity = 256
ttl_ms = 1500
"#,
            "parse knowdb with cache spec",
        );

        assert!(!conf.cache.enabled);
        assert_eq!(conf.cache.capacity, 256);
        assert_eq!(conf.cache.ttl_ms, 1500);
    }
}