// leetcode_api/dao/mod.rs

1pub mod query;
2pub mod save_info;
3
4use std::future::Future;
5
6use lcode_config::global::G_DATABASE_PATH;
7use miette::{IntoDiagnostic, Result};
8use sea_orm::{
9    sea_query::OnConflict, ActiveModelTrait, ConnectionTrait, Database, DatabaseConnection,
10    EntityTrait, IntoActiveModel, ModelTrait, Schema,
11};
12use tokio::{join, sync::OnceCell};
13use tracing::{debug, error};
14
15use crate::entities::{prelude::*, *};
16
17pub trait InsertToDB: std::marker::Sized {
18    type Value: Into<sea_orm::Value> + Send;
19    type Entity: EntityTrait;
20    type Model: ModelTrait + Default + IntoActiveModel<Self::ActiveModel>;
21    type ActiveModel: ActiveModelTrait<Entity = Self::Entity>
22        + std::marker::Send
23        + std::convert::From<Self::Model>;
24
25    fn to_model(&self, _info: Self::Value) -> Self::Model {
26        Self::Model::default()
27    }
28    /// Insert with extra logic
29    ///
30    /// * `_info`: extra info
31    fn insert_to_db(&mut self, _info: Self::Value) -> impl Future<Output = ()> + Send {
32        async {}
33    }
34    fn to_activemodel(&self, value: Self::Value) -> Self::ActiveModel {
35        self.to_model(value).into_active_model()
36    }
37    /// Insert One
38    ///
39    /// * `_info`: extra info
40    fn insert_one(&self, info: Self::Value) -> impl Future<Output = ()> + Send {
41        let pat = self.to_activemodel(info);
42        async {
43            if let Err(err) = Self::Entity::insert(pat)
44                .on_conflict(Self::on_conflict())
45                .exec(glob_db().await)
46                .await
47            {
48                error!("{}", err);
49            }
50        }
51    }
52    fn insert_many(value: Vec<Self::ActiveModel>) -> impl Future<Output = ()> + Send {
53        async {
54            if let Err(err) = Self::Entity::insert_many(value)
55                .on_conflict(Self::on_conflict())
56                .exec(glob_db().await)
57                .await
58            {
59                error!("{}", err);
60            }
61        }
62    }
63    fn on_conflict() -> OnConflict;
64}
65
66pub static DB: OnceCell<DatabaseConnection> = OnceCell::const_new();
67/// # Initialize the db connection
68pub async fn glob_db() -> &'static DatabaseConnection {
69    DB.get_or_init(|| async {
70        let db = conn_db().await.unwrap_or_default();
71
72        let builder = db.get_database_backend();
73        let schema = Schema::new(builder);
74
75        let stmt_index = builder.build(
76            schema
77                .create_table_from_entity(Index)
78                .if_not_exists(),
79        );
80        let stmt_detail = builder.build(
81            schema
82                .create_table_from_entity(Detail)
83                .if_not_exists(),
84        );
85
86        let stmt_newindexdb = builder.build(
87            schema
88                .create_table_from_entity(NewIndexDB)
89                .if_not_exists(),
90        );
91        let stmt_topictagsdb = builder.build(
92            schema
93                .create_table_from_entity(TopicTagsDB)
94                .if_not_exists(),
95        );
96        let stmt_qstagdb = builder.build(
97            schema
98                .create_table_from_entity(QsTagDB)
99                .if_not_exists(),
100        );
101
102        // new table
103        let res = join!(
104            db.execute(stmt_index),
105            db.execute(stmt_detail),
106            db.execute(stmt_newindexdb),
107            db.execute(stmt_topictagsdb),
108            db.execute(stmt_qstagdb)
109        );
110
111        macro_rules! log_errors {
112            ($($res:expr),*) => {
113                $(
114                    if let Err(err) = $res {
115                        error!("{}", err);
116                    }
117                )*
118            };
119        }
120        log_errors!(res.0, res.1, res.2, res.3, res.4);
121
122        db
123    })
124    .await
125}
126/// get database connection
127async fn conn_db() -> Result<DatabaseConnection> {
128    let db_conn_str = format!("sqlite:{}?mode=rwc", G_DATABASE_PATH.to_string_lossy());
129    debug!("database dir: {}", &db_conn_str);
130
131    Database::connect(db_conn_str)
132        .await
133        .into_diagnostic()
134}